diff --git a/data/test/images/face_emotion.jpg b/data/test/images/face_emotion.jpg new file mode 100644 index 00000000..54f22280 --- /dev/null +++ b/data/test/images/face_emotion.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:712b5525e37080d33f62d6657609dbef20e843ccc04ee5c788ea11aa7c08545e +size 123341 diff --git a/data/test/images/face_human_hand_detection.jpg b/data/test/images/face_human_hand_detection.jpg new file mode 100644 index 00000000..f94bb547 --- /dev/null +++ b/data/test/images/face_human_hand_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fddc7be8381eb244cd692601f1c1e6cf3484b44bb4e73df0bc7de29352eb487 +size 23889 diff --git a/data/test/images/product_segmentation.jpg b/data/test/images/product_segmentation.jpg new file mode 100644 index 00000000..c188a69e --- /dev/null +++ b/data/test/images/product_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a16038f7809127eb3e03cbae049592d193707e095309daca78f7d108d67fe4ec +size 108357 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index f94d4103..b7194b8d 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -40,6 +40,9 @@ class Models(object): ulfd = 'ulfd' video_inpainting = 'video-inpainting' hand_static = 'hand-static' + face_human_hand_detection = 'face-human-hand-detection' + face_emotion = 'face-emotion' + product_segmentation = 'product-segmentation' # EasyCV models yolox = 'YOLOX' @@ -179,9 +182,16 @@ class Pipelines(object): movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' shop_segmentation = 'shop-segmentation' video_inpainting = 'video-inpainting' + pst_action_recognition = 'patchshift-action-recognition' hand_static = 'hand-static' + face_human_hand_detection = 'face-human-hand-detection' + face_emotion = 'face-emotion' + product_segmentation = 'product-segmentation' # nlp tasks + automatic_post_editing = 'automatic-post-editing' + translation_quality_estimation = 'translation-quality-estimation' + domain_classification = 'domain-classification' sentence_similarity = 'sentence-similarity' word_segmentation = 'word-segmentation' part_of_speech = 'part-of-speech' diff --git a/modelscope/metrics/image_portrait_enhancement_metric.py b/modelscope/metrics/image_portrait_enhancement_metric.py index b8412b9e..5a81e956 100644 --- a/modelscope/metrics/image_portrait_enhancement_metric.py +++ b/modelscope/metrics/image_portrait_enhancement_metric.py @@ -1,3 +1,5 @@ +# Part of the implementation is borrowed and modified from BasicSR, publicly available at +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py from typing import Dict import numpy as np diff --git a/modelscope/models/cv/action_recognition/__init__.py b/modelscope/models/cv/action_recognition/__init__.py index 7bdee0cd..5e9dc310 100644 --- a/modelscope/models/cv/action_recognition/__init__.py +++ b/modelscope/models/cv/action_recognition/__init__.py @@ -7,11 +7,13 @@ if TYPE_CHECKING: from .models import BaseVideoModel from .tada_convnext import TadaConvNeXt + from .temporal_patch_shift_transformer import PatchShiftTransformer else: _import_structure = { 'models': ['BaseVideoModel'], 'tada_convnext': ['TadaConvNeXt'], + 'temporal_patch_shift_transformer': ['PatchShiftTransformer'] } import sys diff --git a/modelscope/models/cv/action_recognition/temporal_patch_shift_transformer.py b/modelscope/models/cv/action_recognition/temporal_patch_shift_transformer.py new file mode 100644 index 
00000000..46596afd --- /dev/null +++ b/modelscope/models/cv/action_recognition/temporal_patch_shift_transformer.py @@ -0,0 +1,1198 @@ +# Part of the implementation is borrowed and modified from Video Swin Transformer, +# publicly available at https://github.com/SwinTransformer/Video-Swin-Transformer + +from abc import ABCMeta, abstractmethod +from functools import lru_cache, reduce +from operator import mul + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import torchvision.transforms as T +from einops import rearrange +from timm.models.layers import DropPath, Mlp, trunc_normal_ + +from modelscope.models import TorchModel + + +def normal_init(module, mean=0., std=1., bias=0.): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.normal_(module.weight, mean, std) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def window_partition(x, window_size): + """ window_partition function. + Args: + x: (B, D, H, W, C) + window_size (tuple[int]): window size + + Returns: + windows: (B*num_windows, window_size*window_size, C) + """ + B, D, H, W, C = x.shape + x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], + window_size[1], W // window_size[2], window_size[2], C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, + 7).contiguous().view(-1, reduce(mul, window_size), C) + return windows + + +def window_reverse(windows, window_size, B, D, H, W): + """ window_reverse function. + Args: + windows: (B*num_windows, window_size, window_size, C) + window_size (tuple[int]): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, D, H, W, C) + """ + x = windows.view(B, D // window_size[0], H // window_size[1], + W // window_size[2], window_size[0], window_size[1], + window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) + return x + + +def get_window_size(x_size, window_size, shift_size=None): + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +class WindowAttention3D(nn.Module): + """ This is PyTorch impl of TPS + + Window based multi-head self attention (W-MSA) module with relative position bias. + The coordinates of patches and patches are shifted together using Pattern C. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The temporal length, height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + shift (bool, optional): If True, conduct shift operation + shift_type (str, optional): shift operation type, either using 'psm' or 'tsm' + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + shift=False, + shift_type='psm'): + + super().__init__() + self.dim = dim + window_size = (16, 7, 7) + self.window_size = window_size # Wd, Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.shift = shift + self.shift_type = shift_type + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + np.prod([2 * ws - 1 for ws in window_size]), + num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack( + torch.meshgrid(coords_d, coords_h, coords_w, + indexing='ij')) # 3, Wd, Wh, Ww + # Do the same rotation to coords + coords_old = coords.clone() + + # pattern patternC - 9 + coords[:, :, 0::3, 0::3] = torch.roll( + coords[:, :, 0::3, 0::3], shifts=-4, dims=1) + coords[:, :, 0::3, 1::3] = torch.roll( + coords[:, :, 0::3, 1::3], shifts=1, dims=1) + coords[:, :, 0::3, 2::3] = torch.roll( + coords[:, :, 0::3, 2::3], shifts=2, dims=1) + coords[:, :, 1::3, 2::3] = torch.roll( + coords[:, :, 1::3, 2::3], shifts=3, dims=1) + coords[:, :, 1::3, 0::3] = torch.roll( + coords[:, :, 1::3, 0::3], shifts=-1, dims=1) + coords[:, :, 2::3, 0::3] = torch.roll( + coords[:, :, 2::3, 0::3], shifts=-2, dims=1) + coords[:, :, 2::3, 1::3] = torch.roll( + coords[:, :, 2::3, 1::3], shifts=-3, dims=1) + coords[:, :, 2::3, 2::3] = torch.roll( + coords[:, :, 2::3, 2::3], shifts=4, dims=1) + + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + coords_old_flatten = torch.flatten(coords_old, 1) + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords_old = coords_old_flatten[:, :, + None] - coords_old_flatten[:, + None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords_old = relative_coords_old.permute( + 1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords_old[:, :, 0] += self.window_size[ + 0] - 1 # shift to start from 0 + relative_coords_old[:, :, 1] += self.window_size[1] - 1 + relative_coords_old[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.window_size[1] + - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) + + relative_coords_old[:, :, 0] *= (2 * self.window_size[1] + - 1) * (2 * self.window_size[2] - 1) + relative_coords_old[:, :, 1] *= (2 * self.window_size[2] - 1) + + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + + relative_position_index_old = relative_coords_old.sum(-1) + relative_position_index = relative_position_index.view( + window_size[0], window_size[1] * window_size[2], window_size[0], + window_size[1] * window_size[2]).permute(0, 2, 1, 3).reshape( + window_size[0] * window_size[0], + window_size[1] * window_size[2], + 
window_size[1] * window_size[2])[::window_size[0], :, :] + + relative_position_index_old = relative_position_index_old.view( + window_size[0], window_size[1] * window_size[2], window_size[0], + window_size[1] * window_size[2]).permute(0, 2, 1, 3).reshape( + window_size[0] * window_size[0], + window_size[1] * window_size[2], + window_size[1] * window_size[2])[::window_size[0], :, :] + + self.register_buffer('relative_position_index', + relative_position_index) + self.register_buffer('relative_position_index_old', + relative_position_index_old) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + if self.shift and self.shift_type == 'psm': + self.shift_op = PatchShift(False, 1) + self.shift_op_back = PatchShift(True, 1) + elif self.shift and self.shift_type == 'tsm': + self.shift_op = TemporalShift(8) + + def forward(self, x, mask=None, batch_size=8, frame_len=8): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, N, N) or None + """ + B_, N, C = x.shape + if self.shift: + x = x.view(B_, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + x = self.shift_op(x, batch_size, frame_len) + x = x.permute(0, 2, 1, 3).reshape(B_, N, C) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if self.shift and self.shift_type == 'psm': + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:].reshape(-1), :].reshape( + frame_len, N, N, -1) # 8frames ,Wd*Wh*Ww,Wd*Wh*Ww,nH + else: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index_old[:].reshape(-1), :].reshape( + frame_len, N, N, -1) # 8frames ,Wd*Wh*Ww,Wd*Wh*Ww,nH + + relative_position_bias = relative_position_bias.permute( + 0, 3, 1, 2).contiguous() # Frames, nH, Wd*Wh*Ww, Wd*Wh*Ww + + attn = attn.view( + batch_size, frame_len, -1, self.num_heads, N, N).permute( + 0, + 2, 1, 3, 4, 5) + relative_position_bias.unsqueeze(0).unsqueeze( + 1) # B_, nH, N, N + attn = attn.permute(0, 2, 1, 3, 4, 5).view(-1, self.num_heads, N, N) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + # Shift back for psm + if self.shift and self.shift_type == 'psm': + x = self.shift_op_back(attn @ v, batch_size, + frame_len).transpose(1, + 2).reshape(B_, N, C) + else: + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class PatchShift(nn.Module): + """ This is PyTorch impl of TPS + + The patches are shifted using Pattern C. + + It supports both of shifted and shift back. 
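# [Editorial illustration, not part of the diff] A minimal, self-contained sketch of the
# "Pattern C" shift that the PatchShift module above implements: patches are grouped by
# their spatial position modulo 3 and each group is rolled along the temporal axis by its
# own offset, with the (1, 1) group left in place. The offsets below mirror the torch.roll
# calls in PatchShift.shift; the simplified (T, H, W) layout is for illustration only.
import torch

def pattern_c_shift(feat: torch.Tensor, inv: bool = False) -> torch.Tensor:
    """Roll each (row % 3, col % 3) patch group of a (T, H, W) tensor along time."""
    offsets = {(0, 0): -4, (0, 1): 1, (0, 2): 2, (1, 0): -1,
               (1, 2): 3, (2, 0): -2, (2, 1): -3, (2, 2): 4}
    out = feat.clone()
    for (r, c), s in offsets.items():
        out[:, r::3, c::3] = torch.roll(feat[:, r::3, c::3],
                                        shifts=-s if inv else s, dims=0)
    return out

x = torch.arange(8 * 7 * 7, dtype=torch.float32).view(8, 7, 7)
assert torch.equal(pattern_c_shift(pattern_c_shift(x), inv=True), x)  # inverse shift restores x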
+ + Args: + inv (bool): whether using inverse shifted (shift back) + ratio (float): ratio of channels to be shifted, patch shift using 1.0 + """ + + def __init__(self, inv=False, ratio=1): + super(PatchShift, self).__init__() + self.inv = inv + self.ratio = ratio + # if inv: + # print('=> Using inverse PatchShift, ratio {}, tps'.format(ratio)) + # else: + # print('=> Using bayershift, ratio {}, tps'.format(ratio)) + + def forward(self, x, batch_size, frame_len): + x = self.shift( + x, + inv=self.inv, + ratio=self.ratio, + batch_size=batch_size, + frame_len=frame_len) + return x + + @staticmethod + def shift(x, inv=False, ratio=0.5, batch_size=8, frame_len=8): + B, num_heads, N, c = x.size() + fold = int(num_heads * ratio) + feat = x + feat = feat.view(batch_size, frame_len, -1, num_heads, 7, 7, c) + out = feat.clone() + multiplier = 1 + stride = 1 + if inv: + multiplier = -1 + + # Pattern C + out[:, :, :, :fold, 0::3, 0::3, :] = torch.roll( + feat[:, :, :, :fold, 0::3, 0::3, :], + shifts=-4 * multiplier * stride, + dims=1) + out[:, :, :, :fold, 0::3, 1::3, :] = torch.roll( + feat[:, :, :, :fold, 0::3, 1::3, :], + shifts=multiplier * stride, + dims=1) + out[:, :, :, :fold, 1::3, 0::3, :] = torch.roll( + feat[:, :, :, :fold, 1::3, 0::3, :], + shifts=-multiplier * stride, + dims=1) + out[:, :, :, :fold, 0::3, 2::3, :] = torch.roll( + feat[:, :, :, :fold, 0::3, 2::3, :], + shifts=2 * multiplier * stride, + dims=1) + out[:, :, :, :fold, 2::3, 0::3, :] = torch.roll( + feat[:, :, :, :fold, 2::3, 0::3, :], + shifts=-2 * multiplier * stride, + dims=1) + out[:, :, :, :fold, 1::3, 2::3, :] = torch.roll( + feat[:, :, :, :fold, 1::3, 2::3, :], + shifts=3 * multiplier * stride, + dims=1) + out[:, :, :, :fold, 2::3, 1::3, :] = torch.roll( + feat[:, :, :, :fold, 2::3, 1::3, :], + shifts=-3 * multiplier * stride, + dims=1) + out[:, :, :, :fold, 2::3, 2::3, :] = torch.roll( + feat[:, :, :, :fold, 2::3, 2::3, :], + shifts=4 * multiplier * stride, + dims=1) + + out = out.view(B, num_heads, N, c) + return out + + +class TemporalShift(nn.Module): + """ This is PyTorch impl of TPS + + The temporal channel shift. + + The code is adopted from TSM: Temporal Shift Module for Efficient Video Understanding. ICCV19 + + https://github.com/mit-han-lab/temporal-shift-module/blob/master/ops/temporal_shift.py + + Args: + n_div (int): propotion of channel to be shifted. + """ + + def __init__(self, n_div=8): + super(TemporalShift, self).__init__() + self.fold_div = n_div + + def forward(self, x, batch_size, frame_len): + x = self.shift( + x, + fold_div=self.fold_div, + batch_size=batch_size, + frame_len=frame_len) + return x + + @staticmethod + def shift(x, fold_div=8, batch_size=8, frame_len=8): + B, num_heads, N, c = x.size() + fold = c // fold_div + feat = x + feat = feat.view(batch_size, frame_len, -1, num_heads, N, c) + out = feat.clone() + + out[:, 1:, :, :, :, :fold] = feat[:, :-1, :, :, :, :fold] # shift left + out[:, :-1, :, :, :, + fold:2 * fold] = feat[:, 1:, :, :, :, fold:2 * fold] # shift right + + out = out.view(B, num_heads, N, c) + + return out + + +class SwinTransformerBlock3D(nn.Module): + """ Swin Transformer Block from Video Swin Transformer. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=(2, 7, 7), + shift_size=(0, 0, 0), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_checkpoint=False, + shift=False, + shift_type='psm'): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + self.shift = shift + self.shift_type = shift_type + + assert 0 <= self.shift_size[0] < self.window_size[ + 0], 'shift_size must in 0-window_size' + assert 0 <= self.shift_size[1] < self.window_size[ + 1], 'shift_size must in 0-window_size' + assert 0 <= self.shift_size[2] < self.window_size[ + 2], 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention3D( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + shift=self.shift, + shift_type=self.shift_type) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward_part1(self, x, mask_matrix): + B, D, H, W, C = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, + self.shift_size) + + x = self.norm1(x) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] + pad_b = (window_size[1] - H % window_size[1]) % window_size[1] + pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, Dp, Hp, Wp, _ = x.shape + # cyclic shift + if any(i > 0 for i in shift_size): + shifted_x = torch.roll( + x, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + # partition windows + x_windows = window_partition(shifted_x, + window_size) # B*nW, Wd*Wh*Ww, C + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask, batch_size=B, + frame_len=D) # B*nW, Wd*Wh*Ww, C + # merge windows + attn_windows = attn_windows.view(-1, *(window_size + (C, ))) + shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, + Wp) # B D' H' W' C + # reverse cyclic shift + if any(i > 0 for i in shift_size): + x = torch.roll( + shifted_x, + shifts=(shift_size[0], shift_size[1], shift_size[2]), + dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :D, :H, :W, :].contiguous() + return x + + def forward_part2(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). + mask_matrix: Attention mask for cyclic shift. 
+ """ + + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer from Video Swin Transformer. + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). + """ + B, D, H, W, C = x.shape + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C + x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C + x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C + x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +@lru_cache() +def compute_mask(D, H, W, window_size, shift_size, device): + img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 + cnt = 0 + for d in slice(-window_size[0]), slice(-window_size[0], + -shift_size[0]), slice( + -shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], + -shift_size[1]), slice( + -shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], + -shift_size[2]), slice( + -shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, + window_size) # nW, ws[0]*ws[1]*ws[2], 1 + mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + return attn_mask + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage from Video Swin Transformer. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (tuple[int]): Local window size. Default: (1,7,7). + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(1, 7, 7), + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + shift_type='psm'): + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.shift_type = shift_type + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock3D( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + shift=True, + shift_type='tsm' if (i % 2 == 0 and self.shift_type == 'psm') + or self.shift_type == 'tsm' else 'psm', + ) for i in range(depth) + ]) + + self.downsample = downsample + if self.downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, C, D, H, W). + """ + # calculate attention mask for SW-MSA + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, + self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(B, D, H, W, -1) + + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, 'b d h w c -> b c d h w') + return x + + +class PatchEmbed3D(nn.Module): + """ Video to Patch Embedding from Video Swin Transformer. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad( + x, + (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + x = self.proj(x) # B C D Wh Ww + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + + return x + + +class SwinTransformer2D_TPS(nn.Module): + """ + Code is adopted from Video Swin Transformer. + + Args: + patch_size (int | tuple(int)): Patch size. Default: (4,4,4). 
+ in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer: Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + """ + + def __init__(self, + pretrained=None, + pretrained2d=True, + patch_size=(4, 4, 4), + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(2, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=False, + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrained = pretrained + self.pretrained2d = pretrained2d + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.frozen_stages = frozen_stages + self.window_size = window_size + self.patch_size = patch_size + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging + if i_layer < self.num_layers - 1 else None, + use_checkpoint=use_checkpoint, + shift_type='psm') + self.layers.append(layer) + + self.num_features = int(embed_dim * 2**(self.num_layers - 1)) + + # add a norm layer for each output + self.norm = norm_layer(self.num_features) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1: + self.pos_drop.eval() + for i in range(0, self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def inflate_weights(self): + """Inflate the swin2d parameters to swin3d. + + The differences between swin3d and swin2d mainly lie in an extra + axis. To utilize the pretrained parameters in 2d model, + the weight of swin2d models should be inflated to fit in the shapes of + the 3d counterpart. + + Args: + logger (logging.Logger): The logger used to print + debugging infomation. 
+ """ + checkpoint = torch.load(self.pretrained, map_location='cpu') + state_dict = checkpoint['model'] + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_position_index' in k + ] + for k in relative_position_index_keys: + del state_dict[k] + + # delete attn_mask since we always re-init it + attn_mask_keys = [k for k in state_dict.keys() if 'attn_mask' in k] + for k in attn_mask_keys: + del state_dict[k] + + state_dict['patch_embed.proj.weight'] = state_dict[ + 'patch_embed.proj.weight'].unsqueeze(2).repeat( + 1, 1, self.patch_size[0], 1, 1) / self.patch_size[0] + + # bicubic interpolate relative_position_bias_table if not match + relative_position_bias_table_keys = [ + k for k in state_dict.keys() if 'relative_position_bias_table' in k + ] + for k in relative_position_bias_table_keys: + relative_position_bias_table_pretrained = state_dict[k] + relative_position_bias_table_current = self.state_dict()[k] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + # wd = self.window_size[0] + # to make it match + wd = 16 + if nH1 != nH2: + print(f'Error in loading {k}, passing') + else: + if L1 != L2: + S1 = int(L1**0.5) + relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( + relative_position_bias_table_pretrained.permute( + 1, 0).view(1, nH1, S1, S1), + size=(2 * self.window_size[1] - 1, + 2 * self.window_size[2] - 1), + mode='bicubic') + relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view( + nH2, L2).permute(1, 0) + state_dict[k] = relative_position_bias_table_pretrained.repeat( + 2 * wd - 1, 1) + + msg = self.load_state_dict(state_dict, strict=False) + print(msg) + print(f"=> loaded successfully '{self.pretrained}'") + del checkpoint + torch.cuda.empty_cache() + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if pretrained: + self.pretrained = pretrained + if isinstance(self.pretrained, str): + self.apply(_init_weights) + print(f'load model from: {self.pretrained}') + + if self.pretrained2d: + # Inflate 2D model into 3D model. + # self.inflate_weights(logger) + self.inflate_weights() + else: + # Directly load 3D model. + torch.load_checkpoint(self, self.pretrained, strict=False) + elif self.pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x.contiguous()) + + x = rearrange(x, 'n c d h w -> n d h w c') + x = self.norm(x) + x = rearrange(x, 'n d h w c -> n c d h w') + + return x + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer2D_TPS, self).train(mode) + self._freeze_stages() + + +def top_k_accuracy(scores, labels, topk=(1, )): + """Calculate top k accuracy score from mmaction. 
+ + Args: + scores (list[np.ndarray]): Prediction scores for each class. + labels (list[int]): Ground truth labels. + topk (tuple[int]): K value for top_k_accuracy. Default: (1, ). + + Returns: + list[float]: Top k accuracy score for each k. + """ + res = [] + labels = np.array(labels)[:, np.newaxis] + for k in topk: + max_k_preds = np.argsort(scores, axis=1)[:, -k:][:, ::-1] + match_array = np.logical_or.reduce(max_k_preds == labels, axis=1) + topk_acc_score = match_array.sum() / match_array.shape[0] + res.append(topk_acc_score) + + return res + + +class BaseHead(nn.Module, metaclass=ABCMeta): + """Base class for head from mmaction. + + All Head should subclass it. + All subclass should overwrite: + - Methods:``init_weights``, initializing weights in some modules. + - Methods:``forward``, supporting to forward both for training and testing. + + Args: + num_classes (int): Number of classes to be classified. + in_channels (int): Number of channels in input feature. + loss_cls (dict): Config for building loss. + Default: dict(type='CrossEntropyLoss', loss_weight=1.0). + multi_class (bool): Determines whether it is a multi-class + recognition task. Default: False. + label_smooth_eps (float): Epsilon used in label smooth. + Reference: arxiv.org/abs/1906.02629. Default: 0. + """ + + def __init__(self, + num_classes, + in_channels, + loss_cls=dict(type='CrossEntropyLoss', loss_weight=1.0), + multi_class=False, + label_smooth_eps=0.0): + super().__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.loss_cls = torch.nn.CrossEntropyLoss() + self.multi_class = multi_class + self.label_smooth_eps = label_smooth_eps + + @abstractmethod + def init_weights(self): + """Initiate the parameters either from existing checkpoint or from + scratch.""" + + @abstractmethod + def forward(self, x): + """Defines the computation performed at every call.""" + + def loss(self, cls_score, labels, **kwargs): + """Calculate the loss given output ``cls_score``, target ``labels``. + + Args: + cls_score (torch.Tensor): The output of the model. + labels (torch.Tensor): The target output of the model. + + Returns: + dict: A dict containing field 'loss_cls'(mandatory) + and 'top1_acc', 'top5_acc'(optional). + """ + losses = dict() + if labels.shape == torch.Size([]): + labels = labels.unsqueeze(0) + elif labels.dim() == 1 and labels.size()[0] == self.num_classes \ + and cls_score.size()[0] == 1: + # Fix a bug when training with soft labels and batch size is 1. + # When using soft labels, `labels` and `cls_socre` share the same + # shape. + labels = labels.unsqueeze(0) + + if not self.multi_class and cls_score.size() != labels.size(): + top_k_acc = top_k_accuracy(cls_score.detach().cpu().numpy(), + labels.detach().cpu().numpy(), (1, 5)) + losses['top1_acc'] = torch.tensor( + top_k_acc[0], device=cls_score.device) + losses['top5_acc'] = torch.tensor( + top_k_acc[1], device=cls_score.device) + + elif self.multi_class and self.label_smooth_eps != 0: + labels = ((1 - self.label_smooth_eps) * labels + + self.label_smooth_eps / self.num_classes) + + loss_cls = self.loss_cls(cls_score, labels, **kwargs) + # loss_cls may be dictionary or single tensor + if isinstance(loss_cls, dict): + losses.update(loss_cls) + else: + losses['loss_cls'] = loss_cls + + return losses + + +class I3DHead(BaseHead): + """Classification head for I3D from mmaction. + + Args: + num_classes (int): Number of classes to be classified. + in_channels (int): Number of channels in input feature. 
+ loss_cls (dict): Config for building loss. + Default: dict(type='CrossEntropyLoss') + spatial_type (str): Pooling type in spatial dimension. Default: 'avg'. + dropout_ratio (float): Probability of dropout layer. Default: 0.5. + init_std (float): Std value for Initiation. Default: 0.01. + kwargs (dict, optional): Any keyword argument to be used to initialize + the head. + """ + + def __init__(self, + num_classes, + in_channels, + loss_cls=dict(type='CrossEntropyLoss'), + spatial_type='avg', + dropout_ratio=0.5, + init_std=0.01, + **kwargs): + super().__init__(num_classes, in_channels, loss_cls, **kwargs) + + self.spatial_type = spatial_type + self.dropout_ratio = dropout_ratio + self.init_std = init_std + if self.dropout_ratio != 0: + self.dropout = nn.Dropout(p=self.dropout_ratio) + else: + self.dropout = None + self.fc_cls = nn.Linear(self.in_channels, self.num_classes) + + if self.spatial_type == 'avg': + # use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels. + self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) + else: + self.avg_pool = None + + def init_weights(self): + """Initiate the parameters from scratch.""" + normal_init(self.fc_cls, std=self.init_std) + + def forward(self, x): + """Defines the computation performed at every call. + + Args: + x (torch.Tensor): The input data. + + Returns: + torch.Tensor: The classification scores for input samples. + """ + # [N, in_channels, 4, 7, 7] + if self.avg_pool is not None: + x = self.avg_pool(x) + # [N, in_channels, 1, 1, 1] + if self.dropout is not None: + x = self.dropout(x) + # [N, in_channels, 1, 1, 1] + x = x.view(x.shape[0], -1) + # [N, in_channels] + cls_score = self.fc_cls(x) + # [N, num_classes] + return cls_score + + +class PatchShiftTransformer(TorchModel): + """ This is PyTorch impl of PST: + Spatiotemporal Self-attention Modeling with Temporal Patch Shift for Action Recognition, ECCV22. + """ + + def __init__(self, + model_dir=None, + num_classes=400, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + embed_dim=96, + in_channels=768, + pretrained=None): + super().__init__(model_dir) + self.backbone = SwinTransformer2D_TPS( + pretrained=pretrained, + pretrained2d=True, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=(1, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + frozen_stages=-1, + use_checkpoint=False) + self.cls_head = I3DHead( + num_classes=num_classes, in_channels=in_channels) + + def forward(self, x): + feature = self.backbone(x) + output = self.cls_head(feature) + return output diff --git a/modelscope/models/cv/body_2d_keypoints/hrnet_v2.py b/modelscope/models/cv/body_2d_keypoints/hrnet_v2.py index 1570c8cc..ebd69adb 100644 --- a/modelscope/models/cv/body_2d_keypoints/hrnet_v2.py +++ b/modelscope/models/cv/body_2d_keypoints/hrnet_v2.py @@ -1,3 +1,5 @@ +# The implementation is based on HRNET, available at https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation. + import os import numpy as np diff --git a/modelscope/models/cv/body_2d_keypoints/w48.py b/modelscope/models/cv/body_2d_keypoints/w48.py index 7140f8fe..e0317991 100644 --- a/modelscope/models/cv/body_2d_keypoints/w48.py +++ b/modelscope/models/cv/body_2d_keypoints/w48.py @@ -1,3 +1,5 @@ +# The implementation is based on HRNET, available at https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation. 
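# [Editorial illustration, not part of the diff] A hedged usage sketch for the
# PatchShiftTransformer model added earlier in this diff. The clip length (32 frames) is
# an assumption chosen so that, after the temporal patch size of 2, the backbone sees the
# 16 frames matching the hard-coded (16, 7, 7) relative-position table in
# WindowAttention3D; the 224x224 resolution and batch size are likewise illustrative.
import torch
from modelscope.models.cv.action_recognition import PatchShiftTransformer

model = PatchShiftTransformer(num_classes=400).eval()
clip = torch.rand(1, 3, 32, 224, 224)  # (batch, channels, frames, height, width)
with torch.no_grad():
    scores = model(clip)               # (1, 400) class scores from the I3D head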
+ cfg_128x128_15 = { 'DATASET': { 'TYPE': 'DAMO', diff --git a/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py b/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py index 87cd4962..3e920d12 100644 --- a/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py +++ b/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import logging import os.path as osp from typing import Any, Dict, List, Union diff --git a/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py b/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py index b3eac2e5..b7f0c4a3 100644 --- a/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py +++ b/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py @@ -1,4 +1,4 @@ -# The implementation is based on OSTrack, available at https://github.com/facebookresearch/VideoPose3D +# The implementation is based on VideoPose3D, available at https://github.com/facebookresearch/VideoPose3D import torch import torch.nn as nn diff --git a/modelscope/models/cv/cartoon/facelib/LK/lk.py b/modelscope/models/cv/cartoon/facelib/LK/lk.py index df05e3f9..6fd95ad6 100644 --- a/modelscope/models/cv/cartoon/facelib/LK/lk.py +++ b/modelscope/models/cv/cartoon/facelib/LK/lk.py @@ -1,3 +1,5 @@ +# The implementation is adopted from https://github.com/610265158/Peppa_Pig_Face_Engine + import numpy as np from modelscope.models.cv.cartoon.facelib.config import config as cfg diff --git a/modelscope/models/cv/cartoon/facelib/config.py b/modelscope/models/cv/cartoon/facelib/config.py index d795fdde..92b39db0 100644 --- a/modelscope/models/cv/cartoon/facelib/config.py +++ b/modelscope/models/cv/cartoon/facelib/config.py @@ -1,3 +1,5 @@ +# The implementation is adopted from https://github.com/610265158/Peppa_Pig_Face_Engine + import os import numpy as np diff --git a/modelscope/models/cv/cartoon/facelib/face_detector.py b/modelscope/models/cv/cartoon/facelib/face_detector.py index e5589719..fa36d662 100644 --- a/modelscope/models/cv/cartoon/facelib/face_detector.py +++ b/modelscope/models/cv/cartoon/facelib/face_detector.py @@ -1,3 +1,5 @@ +# The implementation is adopted from https://github.com/610265158/Peppa_Pig_Face_Engine + import time import cv2 diff --git a/modelscope/models/cv/cartoon/facelib/face_landmark.py b/modelscope/models/cv/cartoon/facelib/face_landmark.py index 063d40c3..3b7cc1b9 100644 --- a/modelscope/models/cv/cartoon/facelib/face_landmark.py +++ b/modelscope/models/cv/cartoon/facelib/face_landmark.py @@ -1,3 +1,5 @@ +# The implementation is adopted from https://github.com/610265158/Peppa_Pig_Face_Engine + import cv2 import numpy as np import tensorflow as tf diff --git a/modelscope/models/cv/cartoon/facelib/facer.py b/modelscope/models/cv/cartoon/facelib/facer.py index 62388ab9..c6f34e9c 100644 --- a/modelscope/models/cv/cartoon/facelib/facer.py +++ b/modelscope/models/cv/cartoon/facelib/facer.py @@ -1,3 +1,5 @@ +# The implementation is adopted from https://github.com/610265158/Peppa_Pig_Face_Engine + import time import cv2 diff --git a/modelscope/models/cv/cartoon/mtcnn_pytorch/src/align_trans.py b/modelscope/models/cv/cartoon/mtcnn_pytorch/src/align_trans.py index baa3ba73..eb542042 100644 --- a/modelscope/models/cv/cartoon/mtcnn_pytorch/src/align_trans.py +++ b/modelscope/models/cv/cartoon/mtcnn_pytorch/src/align_trans.py @@ -1,7 +1,5 @@ -""" -Created on Mon Apr 24 15:43:29 2017 -@author: zhaoy -""" +# The implementation is adopted from 
https://github.com/TreB1eN/InsightFace_Pytorch/tree/master/mtcnn_pytorch + import cv2 import numpy as np diff --git a/modelscope/models/cv/cartoon/mtcnn_pytorch/src/matlab_cp2tform.py b/modelscope/models/cv/cartoon/mtcnn_pytorch/src/matlab_cp2tform.py index 96a5f965..ea9fbacf 100644 --- a/modelscope/models/cv/cartoon/mtcnn_pytorch/src/matlab_cp2tform.py +++ b/modelscope/models/cv/cartoon/mtcnn_pytorch/src/matlab_cp2tform.py @@ -1,8 +1,4 @@ -""" -Created on Tue Jul 11 06:54:28 2017 - -@author: zhaoyafei -""" +# The implementation is adopted from https://github.com/TreB1eN/InsightFace_Pytorch/tree/master/mtcnn_pytorch import numpy as np from numpy.linalg import inv, lstsq diff --git a/modelscope/models/cv/cartoon/utils.py b/modelscope/models/cv/cartoon/utils.py index 39712653..59b4e879 100644 --- a/modelscope/models/cv/cartoon/utils.py +++ b/modelscope/models/cv/cartoon/utils.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os import cv2 diff --git a/modelscope/models/cv/face_detection/mogface/__init__.py b/modelscope/models/cv/face_detection/mogface/__init__.py index 8190b649..a58268d0 100644 --- a/modelscope/models/cv/face_detection/mogface/__init__.py +++ b/modelscope/models/cv/face_detection/mogface/__init__.py @@ -1 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from .models.detectors import MogFaceDetector diff --git a/modelscope/models/cv/face_detection/mtcnn/__init__.py b/modelscope/models/cv/face_detection/mtcnn/__init__.py index b11c4740..9fddab9c 100644 --- a/modelscope/models/cv/face_detection/mtcnn/__init__.py +++ b/modelscope/models/cv/face_detection/mtcnn/__init__.py @@ -1 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from .models.detector import MtcnnFaceDetector diff --git a/modelscope/models/cv/face_detection/retinaface/__init__.py b/modelscope/models/cv/face_detection/retinaface/__init__.py index 779aaf1c..e7b589a1 100644 --- a/modelscope/models/cv/face_detection/retinaface/__init__.py +++ b/modelscope/models/cv/face_detection/retinaface/__init__.py @@ -1 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from .detection import RetinaFaceDetection diff --git a/modelscope/models/cv/face_detection/ulfd_slim/__init__.py b/modelscope/models/cv/face_detection/ulfd_slim/__init__.py index 41a2226a..af1e7b42 100644 --- a/modelscope/models/cv/face_detection/ulfd_slim/__init__.py +++ b/modelscope/models/cv/face_detection/ulfd_slim/__init__.py @@ -1 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from .detection import UlfdFaceDetector diff --git a/modelscope/models/cv/face_emotion/__init__.py b/modelscope/models/cv/face_emotion/__init__.py new file mode 100644 index 00000000..2a13ea42 --- /dev/null +++ b/modelscope/models/cv/face_emotion/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .emotion_model import EfficientNetForFaceEmotion + +else: + _import_structure = {'emotion_model': ['EfficientNetForFaceEmotion']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/face_emotion/efficient/__init__.py b/modelscope/models/cv/face_emotion/efficient/__init__.py new file mode 100644 index 00000000..e8fc91a4 --- /dev/null +++ b/modelscope/models/cv/face_emotion/efficient/__init__.py @@ -0,0 +1,6 @@ +# The implementation here is modified based on EfficientNet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/lukemelas/EfficientNet-PyTorch + +from .model import VALID_MODELS, EfficientNet +from .utils import (BlockArgs, BlockDecoder, GlobalParams, efficientnet, + get_model_params) diff --git a/modelscope/models/cv/face_emotion/efficient/model.py b/modelscope/models/cv/face_emotion/efficient/model.py new file mode 100644 index 00000000..db303016 --- /dev/null +++ b/modelscope/models/cv/face_emotion/efficient/model.py @@ -0,0 +1,380 @@ +# The implementation here is modified based on EfficientNet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/lukemelas/EfficientNet-PyTorch + +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import (MemoryEfficientSwish, Swish, calculate_output_image_size, + drop_connect, efficientnet_params, get_model_params, + get_same_padding_conv2d, load_pretrained_weights, + round_filters, round_repeats) + +VALID_MODELS = ('efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', + 'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b5', + 'efficientnet-b6', 'efficientnet-b7', 'efficientnet-b8', + 'efficientnet-l2') + + +class MBConvBlock(nn.Module): + + def __init__(self, block_args, global_params, image_size=None): + super().__init__() + self._block_args = block_args + self._bn_mom = 1 - global_params.batch_norm_momentum + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio + is not None) and (0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip + + inp = self._block_args.input_filters + oup = self._block_args.input_filters * self._block_args.expand_ratio + if self._block_args.expand_ratio != 1: + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._expand_conv = Conv2d( + in_channels=inp, out_channels=oup, kernel_size=1, bias=False) + self._bn0 = nn.BatchNorm2d( + num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + + k = self._block_args.kernel_size + s = self._block_args.stride + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._depthwise_conv = Conv2d( + in_channels=oup, + out_channels=oup, + groups=oup, + kernel_size=k, + stride=s, + bias=False) + self._bn1 = nn.BatchNorm2d( + num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + image_size = calculate_output_image_size(image_size, s) + + if self.has_se: + Conv2d = get_same_padding_conv2d(image_size=(1, 1)) + num_squeezed_channels = max( + 1, + int(self._block_args.input_filters + * self._block_args.se_ratio)) + self._se_reduce = Conv2d( + in_channels=oup, + out_channels=num_squeezed_channels, + kernel_size=1) + self._se_expand = Conv2d( + in_channels=num_squeezed_channels, + out_channels=oup, + kernel_size=1) + + final_oup = 
self._block_args.output_filters + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._project_conv = Conv2d( + in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) + self._bn2 = nn.BatchNorm2d( + num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) + self._swish = MemoryEfficientSwish() + + def forward(self, inputs, drop_connect_rate=None): + """MBConvBlock's forward function. + Args: + inputs (tensor): Input tensor. + drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). + Returns: + Output of this block after processing. + """ + + x = inputs + if self._block_args.expand_ratio != 1: + x = self._expand_conv(inputs) + x = self._bn0(x) + x = self._swish(x) + + x = self._depthwise_conv(x) + x = self._bn1(x) + x = self._swish(x) + + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_reduce(x_squeezed) + x_squeezed = self._swish(x_squeezed) + x_squeezed = self._se_expand(x_squeezed) + x = torch.sigmoid(x_squeezed) * x + + x = self._project_conv(x) + x = self._bn2(x) + + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + if drop_connect_rate: + x = drop_connect( + x, p=drop_connect_rate, training=self.training) + x = x + inputs + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + + +class EfficientNet(nn.Module): + """EfficientNet model. + Most easily loaded with the .from_name or .from_pretrained methods. + Args: + blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. + global_params (namedtuple): A set of GlobalParams shared between blocks. 
+ References: + [1] https://arxiv.org/abs/1905.11946 (EfficientNet) + Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> model.eval() + >>> outputs = model(inputs) + """ + + def __init__(self, blocks_args=None, global_params=None): + super().__init__() + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._global_params = global_params + self._blocks_args = blocks_args + + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + image_size = global_params.image_size + Conv2d = get_same_padding_conv2d(image_size=image_size) + + in_channels = 3 + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0 = nn.BatchNorm2d( + num_features=out_channels, momentum=bn_mom, eps=bn_eps) + image_size = calculate_output_image_size(image_size, 2) + + self._blocks = nn.ModuleList([]) + for block_args in self._blocks_args: + + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, + self._global_params), + output_filters=round_filters(block_args.output_filters, + self._global_params), + num_repeat=round_repeats(block_args.num_repeat, + self._global_params)) + + self._blocks.append( + MBConvBlock( + block_args, self._global_params, image_size=image_size)) + image_size = calculate_output_image_size(image_size, + block_args.stride) + if block_args.num_repeat > 1: + block_args = block_args._replace( + input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append( + MBConvBlock( + block_args, self._global_params, + image_size=image_size)) + + in_channels = block_args.output_filters + out_channels = round_filters(1280, self._global_params) + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._conv_head = Conv2d( + in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm2d( + num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + if self._global_params.include_top: + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + + self._swish = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + def extract_endpoints(self, inputs): + """Use convolution layer to extract features + from reduction levels i in [1, 2, 3, 4, 5]. + Args: + inputs (tensor): Input tensor. + Returns: + Dictionary of last intermediate features + with reduction levels i in [1, 2, 3, 4, 5]. 
+ Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> endpoints = model.extract_endpoints(inputs) + >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112]) + >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56]) + >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28]) + >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14]) + >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7]) + >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7]) + """ + endpoints = dict() + + x = self._swish(self._bn0(self._conv_stem(inputs))) + prev_x = x + + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len( + self._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + if prev_x.size(2) > x.size(2): + endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x + elif idx == len(self._blocks) - 1: + endpoints['reduction_{}'.format(len(endpoints) + 1)] = x + prev_x = x + + x = self._swish(self._bn1(self._conv_head(x))) + endpoints['reduction_{}'.format(len(endpoints) + 1)] = x + + return endpoints + + def extract_features(self, inputs): + """use convolution layer to extract feature . + Args: + inputs (tensor): Input tensor. + Returns: + Output of the final convolution + layer in the efficientnet model. + """ + x = self._swish(self._bn0(self._conv_stem(inputs))) + + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """EfficientNet's forward function. + Calls extract_features to extract features, applies final linear layer, and returns logits. + Args: + inputs (tensor): Input tensor. + Returns: + Output of this model after processing. + """ + x = self.extract_features(inputs) + x = self._avg_pooling(x) + if self._global_params.include_top: + x = x.flatten(start_dim=1) + x = self._dropout(x) + x = self._fc(x) + return x + + @classmethod + def from_name(cls, model_name, in_channels=3, **override_params): + """Create an efficientnet model according to name. + Args: + model_name (str): Name for efficientnet. + in_channels (int): Input data's channel number. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + Returns: + An efficientnet model. + """ + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, + override_params) + model = cls(blocks_args, global_params) + model._change_in_channels(in_channels) + return model + + @classmethod + def from_pretrained(cls, + model_name, + weights_path=None, + advprop=False, + in_channels=3, + num_classes=1000, + **override_params): + """Create an efficientnet model according to name. + Args: + model_name (str): Name for efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. 
+ None: use pretrained weights downloaded from the Internet. + advprop (bool): + Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + in_channels (int): Input data's channel number. + num_classes (int): + Number of categories for classification. + It controls the output size for final linear layer. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + Returns: + A pretrained efficientnet model. + """ + model = cls.from_name( + model_name, num_classes=num_classes, **override_params) + model._change_in_channels(in_channels) + return model + + @classmethod + def get_image_size(cls, model_name): + """Get the input image size for a given efficientnet model. + Args: + model_name (str): Name for efficientnet. + Returns: + Input image size (resolution). + """ + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """Validates model name. + Args: + model_name (str): Name for efficientnet. + Returns: + bool: Is a valid name or not. + """ + if model_name not in VALID_MODELS: + raise ValueError('model_name should be one of: ' + + ', '.join(VALID_MODELS)) + + def _change_in_channels(self, in_channels): + """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. + Args: + in_channels (int): Input data's channel number. + """ + if in_channels != 3: + Conv2d = get_same_padding_conv2d( + image_size=self._global_params.image_size) + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias=False) diff --git a/modelscope/models/cv/face_emotion/efficient/utils.py b/modelscope/models/cv/face_emotion/efficient/utils.py new file mode 100644 index 00000000..6cae70fc --- /dev/null +++ b/modelscope/models/cv/face_emotion/efficient/utils.py @@ -0,0 +1,559 @@ +# The implementation here is modified based on EfficientNet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/lukemelas/EfficientNet-PyTorch + +import collections +import math +import re +from functools import partial + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import model_zoo + +GlobalParams = collections.namedtuple('GlobalParams', [ + 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon', + 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top' +]) + +BlockArgs = collections.namedtuple('BlockArgs', [ + 'num_repeat', 'kernel_size', 'stride', 'expand_ratio', 'input_filters', + 'output_filters', 'se_ratio', 'id_skip' +]) + +GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields) + +if hasattr(nn, 'SiLU'): + Swish = nn.SiLU +else: + + class Swish(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(x) + + +class SwishImplementation(torch.autograd.Function): + + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_tensors[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * 
(sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class MemoryEfficientSwish(nn.Module): + + def forward(self, x): + return SwishImplementation.apply(x) + + +def round_filters(filters, global_params): + """Calculate and round number of filters based on width multiplier. + Use width_coefficient, depth_divisor and min_depth of global_params. + Args: + filters (int): Filters number to be calculated. + global_params (namedtuple): Global params of the model. + Returns: + new_filters: New filters number after calculating. + """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor + new_filters = max(min_depth, + int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """Calculate module's repeat number of a block based on depth multiplier. + Use depth_coefficient of global_params. + Args: + repeats (int): num_repeat to be calculated. + global_params (namedtuple): Global params of the model. + Returns: + new repeat: New repeat number after calculating. + """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +def drop_connect(inputs, p, training): + """Drop connect. + Args: + input (tensor: BCWH): Input of this structure. + p (float: 0.0~1.0): Probability of drop connection. + training (bool): The running mode. + Returns: + output: Output after drop connection. + """ + assert 0 <= p <= 1, 'p must be in range of [0,1]' + + if not training: + return inputs + + batch_size = inputs.shape[0] + keep_prob = 1 - p + + random_tensor = keep_prob + random_tensor += torch.rand([batch_size, 1, 1, 1], + dtype=inputs.dtype, + device=inputs.device) + binary_tensor = torch.floor(random_tensor) + + output = inputs / keep_prob * binary_tensor + return output + + +def get_width_and_height_from_size(x): + """Obtain height and width from x. + Args: + x (int, tuple or list): Data size. + Returns: + size: A tuple or list (H,W). + """ + if isinstance(x, int): + return x, x + if isinstance(x, list) or isinstance(x, tuple): + return x + else: + raise TypeError() + + +def calculate_output_image_size(input_image_size, stride): + """Calculates the output image size when using Conv2dSamePadding with a stride. + Necessary for static padding. Thanks to mannatsingh for pointing this out. + Args: + input_image_size (int, tuple or list): Size of input image. + stride (int, tuple or list): Conv2d operation's stride. + Returns: + output_image_size: A list [H,W]. + """ + if input_image_size is None: + return None + image_height, image_width = get_width_and_height_from_size( + input_image_size) + stride = stride if isinstance(stride, int) else stride[0] + image_height = int(math.ceil(image_height / stride)) + image_width = int(math.ceil(image_width / stride)) + return [image_height, image_width] + + +def get_same_padding_conv2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + Args: + image_size (int or tuple): Size of the image. + Returns: + Conv2dDynamicSamePadding or Conv2dStaticSamePadding. 
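+    Example (illustrative sketch; shape follows 'SAME' padding, i.e. ceil(size / stride)):
+        >>> import torch
+        >>> Conv2d = get_same_padding_conv2d(image_size=224)  # static padding
+        >>> conv = Conv2d(3, 32, kernel_size=3, stride=2, bias=False)
+        >>> conv(torch.rand(1, 3, 224, 224)).shape
+        torch.Size([1, 32, 112, 112])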
+ """ + if image_size is None: + return Conv2dDynamicSamePadding + else: + return partial(Conv2dStaticSamePadding, image_size=image_size) + + +class Conv2dDynamicSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow, for a dynamic image size. + The padding is operated in forward function by calculating dynamically. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, + dilation, groups, bias) + self.stride = self.stride if len( + self.stride) == 2 else [self.stride[0]] * 2 + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + a1 = (oh - 1) * self.stride[0] + pad_h = max(a1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + a2 = (ow - 1) * self.stride[1] + pad_w = max(a2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +class Conv2dStaticSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + image_size=None, + **kwargs): + super().__init__(in_channels, out_channels, kernel_size, stride, + **kwargs) + self.stride = self.stride if len( + self.stride) == 2 else [self.stride[0]] * 2 + + assert image_size is not None + ih, iw = (image_size, + image_size) if isinstance(image_size, int) else image_size + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + b1 = (oh - 1) * self.stride[0] + pad_h = max(b1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + b2 = (ow - 1) * self.stride[1] + pad_w = max(b2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d( + (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) + return x + + +def get_same_padding_maxPool2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + Args: + image_size (int or tuple): Size of the image. + Returns: + MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. + """ + if image_size is None: + return MaxPool2dDynamicSamePadding + else: + return partial(MaxPool2dStaticSamePadding, image_size=image_size) + + +class MaxPool2dDynamicSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. + The padding is operated in forward function by calculating dynamically. 
+ """ + + def __init__(self, + kernel_size, + stride, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False): + super().__init__(kernel_size, stride, padding, dilation, + return_indices, ceil_mode) + self.stride = [self.stride] * 2 if isinstance(self.stride, + int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance( + self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance( + self.dilation, int) else self.dilation + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + c1 = (oh - 1) * self.stride[0] + pad_h = max(c1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + c2 = (ow - 1) * self.stride[1] + pad_w = max(c2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + return F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + + +class MaxPool2dStaticSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + def __init__(self, kernel_size, stride, image_size=None, **kwargs): + super().__init__(kernel_size, stride, **kwargs) + self.stride = [self.stride] * 2 if isinstance(self.stride, + int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance( + self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance( + self.dilation, int) else self.dilation + + assert image_size is not None + ih, iw = (image_size, + image_size) if isinstance(image_size, int) else image_size + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + d1 = (oh - 1) * self.stride[0] + pad_h = max(d1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + d2 = (ow - 1) * self.stride[1] + pad_w = max(d2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d( + (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + return x + + +class BlockDecoder(object): + """Block Decoder for readability, + straight from the official TensorFlow repository. + """ + + @staticmethod + def _decode_block_string(block_string): + """Get a block through a string notation of arguments. + Args: + block_string (str): A string notation of arguments. + Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. + Returns: + BlockArgs: The namedtuple defined at the top of this file. 
+ """ + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) + or (len(options['s']) == 2 + and options['s'][0] == options['s'][1])) + + return BlockArgs( + num_repeat=int(options['r']), + kernel_size=int(options['k']), + stride=[int(options['s'][0])], + expand_ratio=int(options['e']), + input_filters=int(options['i']), + output_filters=int(options['o']), + se_ratio=float(options['se']) if 'se' in options else None, + id_skip=('noskip' not in block_string)) + + @staticmethod + def _encode_block_string(block): + """Encode a block to a string. + Args: + block (namedtuple): A BlockArgs type argument. + Returns: + block_string: A String form of BlockArgs. + """ + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d' % (block.strides[0], block.strides[1]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """Decode a list of string notations to specify blocks inside the network. + Args: + string_list (list[str]): A list of strings, each string is a notation of block. + Returns: + blocks_args: A list of BlockArgs namedtuples of block args. + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """Encode a list of BlockArgs to a list of strings. + Args: + blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. + Returns: + block_strings: A list of strings, each string is a notation of block. + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def efficientnet_params(model_name): + """Map EfficientNet model name to parameter coefficients. + Args: + model_name (str): Model name to be queried. + Returns: + params_dict[model_name]: A (width,depth,res,dropout) tuple. + """ + params_dict = { + 'efficientnet-b0': (1.0, 1.0, 112, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + } + return params_dict[model_name] + + +def efficientnet(width_coefficient=None, + depth_coefficient=None, + image_size=None, + dropout_rate=0.2, + drop_connect_rate=0.2, + num_classes=1000, + include_top=True): + """Create BlockArgs and GlobalParams for efficientnet model. + Args: + width_coefficient (float) + depth_coefficient (float) + image_size (int) + dropout_rate (float) + drop_connect_rate (float) + num_classes (int) + Meaning as the name suggests. + Returns: + blocks_args, global_params. 
+ """ + + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', + 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', + 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', + 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + image_size=image_size, + dropout_rate=dropout_rate, + num_classes=num_classes, + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + drop_connect_rate=drop_connect_rate, + depth_divisor=8, + min_depth=None, + include_top=include_top, + ) + return blocks_args, global_params + + +def get_model_params(model_name, override_params): + """Get the block args and global params for a given model name. + Args: + model_name (str): Model's name. + override_params (dict): A dict to modify global_params. + Returns: + blocks_args, global_params + """ + if model_name.startswith('efficientnet'): + w, d, s, p = efficientnet_params(model_name) + blocks_args, global_params = efficientnet( + width_coefficient=w, + depth_coefficient=d, + dropout_rate=p, + image_size=s) + else: + raise NotImplementedError( + 'model name is not pre-defined: {}'.format(model_name)) + if override_params: + global_params = global_params._replace(**override_params) + return blocks_args, global_params + + +def load_pretrained_weights(model, + model_name, + weights_path=None, + load_fc=True, + advprop=False, + verbose=True): + """Loads pretrained weights from weights path or download using url. + Args: + model (Module): The whole model of efficientnet. + model_name (str): Model name of efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. + advprop (bool): Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + """ + if isinstance(weights_path, str): + state_dict = torch.load(weights_path) + else: + url_map_ = url_map_advprop if advprop else url_map + state_dict = model_zoo.load_url(url_map_[model_name]) + + if load_fc: + ret = model.load_state_dict(state_dict, strict=False) + assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format( + ret.missing_keys) + else: + state_dict.pop('_fc.weight') + state_dict.pop('_fc.bias') + ret = model.load_state_dict(state_dict, strict=False) + assert set(ret.missing_keys) == set([ + '_fc.weight', '_fc.bias' + ]), 'Missing keys when loading pretrained weights: {}'.format( + ret.missing_keys) + assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format( + ret.unexpected_keys) + + if verbose: + print('Loaded pretrained weights for {}'.format(model_name)) diff --git a/modelscope/models/cv/face_emotion/emotion_infer.py b/modelscope/models/cv/face_emotion/emotion_infer.py new file mode 100644 index 00000000..e3398592 --- /dev/null +++ b/modelscope/models/cv/face_emotion/emotion_infer.py @@ -0,0 +1,67 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
+import torch +from PIL import Image +from torch import nn +from torchvision import transforms + +from modelscope.utils.logger import get_logger +from .face_alignment.face_align import face_detection_PIL_v2 + +logger = get_logger() + + +def transform_PIL(img_pil): + val_transforms = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + return val_transforms(img_pil) + + +index2AU = [1, 2, 4, 6, 7, 10, 12, 15, 23, 24, 25, 26] +emotion_list = [ + 'Neutral', 'Anger', 'Disgust', 'Fear', 'Happiness', 'Sadness', 'Surprise' +] + + +def inference(image_path, model, face_model, score_thre=0.5, GPU=0): + image = Image.open(image_path).convert('RGB') + + face, bbox = face_detection_PIL_v2(image, face_model) + if bbox is None: + logger.warn('no face detected!') + result = {'emotion_result': None, 'box': None} + return result + + face = transform_PIL(face) + face = face.unsqueeze(0) + if torch.cuda.is_available(): + face = face.cuda(GPU) + logits_AU, logits_emotion = model(face) + logits_AU = torch.sigmoid(logits_AU) + logits_emotion = nn.functional.softmax(logits_emotion, 1) + + _, index_list = logits_emotion.max(1) + emotion_index = index_list[0].data.item() + prob = logits_emotion[0][emotion_index] + if prob > score_thre and emotion_index != 3: + cur_emotion = emotion_list[emotion_index] + else: + cur_emotion = 'Neutral' + + logits_AU = logits_AU[0] + au_ouput = torch.zeros_like(logits_AU) + au_ouput[logits_AU >= score_thre] = 1 + au_ouput[logits_AU < score_thre] = 0 + + au_ouput = au_ouput.int() + + cur_au_list = [] + for idx in range(au_ouput.shape[0]): + if au_ouput[idx] == 1: + au = index2AU[idx] + cur_au_list.append(au) + cur_au_list.sort() + result = (cur_emotion, bbox) + return result diff --git a/modelscope/models/cv/face_emotion/emotion_model.py b/modelscope/models/cv/face_emotion/emotion_model.py new file mode 100644 index 00000000..f8df9c37 --- /dev/null +++ b/modelscope/models/cv/face_emotion/emotion_model.py @@ -0,0 +1,96 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
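+# Illustrative shape sketch for the model defined below: a 112x112 RGB face crop
+# yields 12 action-unit logits and 7 emotion logits (random weights, eval mode):
+#     model = FaceEmotionModel().eval()
+#     logits_au, logits_emotion = model(torch.rand(1, 3, 112, 112))
+#     # logits_au: (1, 12), logits_emotion: (1, 7)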
+import os +import sys + +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.face_emotion.efficient import EfficientNet +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@MODELS.register_module(Tasks.face_emotion, module_name=Models.face_emotion) +class EfficientNetForFaceEmotion(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + self.model = FaceEmotionModel( + name='efficientnet-b0', num_embed=512, num_au=12, num_emotion=7) + + if torch.cuda.is_available(): + self.device = 'cuda' + logger.info('Use GPU') + else: + self.device = 'cpu' + logger.info('Use CPU') + pretrained_params = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=self.device) + + state_dict = pretrained_params['model'] + new_state = {} + for k, v in state_dict.items(): + if k.startswith('module.'): + k = k[7:] + new_state[k] = v + + self.model.load_state_dict(new_state) + self.model.eval() + self.model.to(self.device) + + def forward(self, x): + logits_au, logits_emotion = self.model(x) + return logits_au, logits_emotion + + +class FaceEmotionModel(nn.Module): + + def __init__(self, + name='efficientnet-b0', + num_embed=512, + num_au=12, + num_emotion=7): + super(FaceEmotionModel, self).__init__() + self.backbone = EfficientNet.from_pretrained( + name, weights_path=None, advprop=True) + self.average_pool = nn.AdaptiveAvgPool2d(1) + self.embed = nn.Linear(self.backbone._fc.weight.data.shape[1], + num_embed) + self.features = nn.BatchNorm1d(num_embed) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + self.fc_au = nn.Sequential( + nn.Dropout(0.6), + nn.Linear(num_embed, num_au), + ) + self.fc_emotion = nn.Sequential( + nn.Dropout(0.6), + nn.Linear(num_embed, num_emotion), + ) + + def feat_single_img(self, x): + x = self.backbone.extract_features(x) + x = self.average_pool(x) + x = x.flatten(1) + x = self.embed(x) + x = self.features(x) + return x + + def forward(self, x): + x = self.feat_single_img(x) + logits_au = self.fc_au(x) + att_au = torch.sigmoid(logits_au).unsqueeze(-1) + x = x.unsqueeze(1) + emotion_vec_list = torch.matmul(att_au, x) + emotion_vec = emotion_vec_list.sum(1) + logits_emotion = self.fc_emotion(emotion_vec) + return logits_au, logits_emotion diff --git a/modelscope/models/cv/face_emotion/face_alignment/__init__.py b/modelscope/models/cv/face_emotion/face_alignment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_emotion/face_alignment/face.py b/modelscope/models/cv/face_emotion/face_alignment/face.py new file mode 100644 index 00000000..a362bddc --- /dev/null +++ b/modelscope/models/cv/face_emotion/face_alignment/face.py @@ -0,0 +1,79 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
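+# Illustrative usage of the TF frozen-graph face detector defined below (the .pb path
+# is an assumption; the actual graph is shipped with the model assets):
+#     detector = FaceDetector('/path/to/frozen_inference_graph.pb')
+#     bboxes, confs = detector.do_detect(cv2.imread('face.jpg'))
+#     # bboxes are pixel [x1, y1, x2, y2] boxes with confidence >= 0.5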
+import os + +import cv2 +import numpy as np +import tensorflow as tf + + +def init(mod): + PATH_TO_CKPT = mod + net = tf.Graph() + with net.as_default(): + od_graph_def = tf.GraphDef() + config = tf.ConfigProto() + config.gpu_options.per_process_gpu_memory_fraction = 0.6 + with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: + serialized_graph = fid.read() + od_graph_def.ParseFromString(serialized_graph) + tf.import_graph_def(od_graph_def, name='') + sess = tf.Session(graph=net, config=config) + return sess, net + + +def filter_bboxes_confs(shape, + imgsBboxes, + imgsConfs, + single=False, + thresh=0.5): + [w, h] = shape + if single: + bboxes, confs = [], [] + for y in range(len(imgsBboxes)): + if imgsConfs[y] >= thresh: + [x1, y1, x2, y2] = list(imgsBboxes[y]) + x1, y1, x2, y2 = int(w * x1), int(h * y1), int(w * x2), int( + h * y2) + bboxes.append([y1, x1, y2, x2]) + confs.append(imgsConfs[y]) + return bboxes, confs + else: + retImgsBboxes, retImgsConfs = [], [] + for x in range(len(imgsBboxes)): + bboxes, confs = [], [] + for y in range(len(imgsBboxes[x])): + if imgsConfs[x][y] >= thresh: + [x1, y1, x2, y2] = list(imgsBboxes[x][y]) + x1, y1, x2, y2 = int(w * x1), int(h * y1), int( + w * x2), int(h * y2) + bboxes.append([y1, x1, y2, x2]) + confs.append(imgsConfs[x][y]) + retImgsBboxes.append(bboxes) + retImgsConfs.append(confs) + return retImgsBboxes, retImgsConfs + + +def detect(im, sess, net): + image_np = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + image_np_expanded = np.expand_dims(image_np, axis=0) + image_tensor = net.get_tensor_by_name('image_tensor:0') + bboxes = net.get_tensor_by_name('detection_boxes:0') + dConfs = net.get_tensor_by_name('detection_scores:0') + classes = net.get_tensor_by_name('detection_classes:0') + num_detections = net.get_tensor_by_name('num_detections:0') + (bboxes, dConfs, classes, + num_detections) = sess.run([bboxes, dConfs, classes, num_detections], + feed_dict={image_tensor: image_np_expanded}) + w, h, _ = im.shape + bboxes, confs = filter_bboxes_confs([w, h], bboxes[0], dConfs[0], True) + return bboxes, confs + + +class FaceDetector: + + def __init__(self, mod): + self.sess, self.net = init(mod) + + def do_detect(self, im): + bboxes, confs = detect(im, self.sess, self.net) + return bboxes, confs diff --git a/modelscope/models/cv/face_emotion/face_alignment/face_align.py b/modelscope/models/cv/face_emotion/face_alignment/face_align.py new file mode 100644 index 00000000..71282b12 --- /dev/null +++ b/modelscope/models/cv/face_emotion/face_alignment/face_align.py @@ -0,0 +1,59 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
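+# face_detection_PIL_v2() below runs the TF face detector on a PIL image, squares the
+# first detected box, and returns a 112x112 crop together with the (x1, y1, x2, y2)
+# box. Illustrative call (face_model is an assumed path to the detector graph):
+#     crop, box = face_detection_PIL_v2(Image.open('face.jpg'), face_model)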
+import os +import sys + +import cv2 +import numpy as np +from PIL import Image, ImageFile + +from .face import FaceDetector + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def adjust_bx_v2(box, w, h): + x1, y1, x2, y2 = box[0], box[1], box[2], box[3] + box_w = x2 - x1 + box_h = y2 - y1 + delta = abs(box_w - box_h) + if box_w > box_h: + if y1 >= delta: + y1 = y1 - delta + else: + delta_y1 = y1 + y1 = 0 + delta_y2 = delta - delta_y1 + y2 = y2 + delta_y2 if y2 < h - delta_y2 else h - 1 + else: + if x1 >= delta / 2 and x2 <= w - delta / 2: + x1 = x1 - delta / 2 + x2 = x2 + delta / 2 + elif x1 < delta / 2 and x2 <= w - delta / 2: + delta_x1 = x1 + x1 = 0 + delta_x2 = delta - delta_x1 + x2 = x2 + delta_x2 if x2 < w - delta_x2 else w - 1 + elif x1 >= delta / 2 and x2 > w - delta / 2: + delta_x2 = w - x2 + x2 = w - 1 + delta_x1 = delta - x1 + x1 = x1 - delta_x1 if x1 >= delta_x1 else 0 + + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + return [x1, y1, x2, y2] + + +def face_detection_PIL_v2(image, face_model): + crop_size = 112 + face_detector = FaceDetector(face_model) + img = np.array(image) + h, w = img.shape[0:2] + bxs, conf = face_detector.do_detect(img) + bx = bxs[0] + bx = adjust_bx_v2(bx, w, h) + x1, y1, x2, y2 = bx + image = img[y1:y2, x1:x2, :] + img = Image.fromarray(image) + img = img.resize((crop_size, crop_size)) + bx = tuple(bx) + return img, bx diff --git a/modelscope/models/cv/face_generation/op/conv2d_gradfix.py b/modelscope/models/cv/face_generation/op/conv2d_gradfix.py index 661f4fc7..a3aba91f 100755 --- a/modelscope/models/cv/face_generation/op/conv2d_gradfix.py +++ b/modelscope/models/cv/face_generation/op/conv2d_gradfix.py @@ -1,3 +1,5 @@ +# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License +# at https://github.com/rosinality/stylegan2-pytorch/blob/master/op/conv2d_gradfix.py import contextlib import warnings diff --git a/modelscope/models/cv/face_generation/op/fused_act.py b/modelscope/models/cv/face_generation/op/fused_act.py index d6e0c10f..a24f5972 100755 --- a/modelscope/models/cv/face_generation/op/fused_act.py +++ b/modelscope/models/cv/face_generation/op/fused_act.py @@ -1,3 +1,5 @@ +# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License +# t https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py import os import torch diff --git a/modelscope/models/cv/face_generation/op/upfirdn2d.py b/modelscope/models/cv/face_generation/op/upfirdn2d.py index 5a44421d..95c987af 100755 --- a/modelscope/models/cv/face_generation/op/upfirdn2d.py +++ b/modelscope/models/cv/face_generation/op/upfirdn2d.py @@ -1,3 +1,5 @@ +# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License +# at https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py import os from collections import abc diff --git a/modelscope/models/cv/face_generation/stylegan2.py b/modelscope/models/cv/face_generation/stylegan2.py index ff9c83ee..4c650f54 100755 --- a/modelscope/models/cv/face_generation/stylegan2.py +++ b/modelscope/models/cv/face_generation/stylegan2.py @@ -1,3 +1,5 @@ +# The implementation is adopted from stylegan2-pytorch, +# made public available under the MIT License at https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py import functools import math import operator diff --git a/modelscope/models/cv/face_human_hand_detection/__init__.py b/modelscope/models/cv/face_human_hand_detection/__init__.py new file mode 
100644 index 00000000..33a5fd2f --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .det_infer import NanoDetForFaceHumanHandDetection + +else: + _import_structure = {'det_infer': ['NanoDetForFaceHumanHandDetection']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/face_human_hand_detection/det_infer.py b/modelscope/models/cv/face_human_hand_detection/det_infer.py new file mode 100644 index 00000000..7a7225ee --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/det_infer.py @@ -0,0 +1,133 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .one_stage_detector import OneStageDetector + +logger = get_logger() + + +def load_model_weight(model_dir, device): + checkpoint = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=device) + state_dict = checkpoint['state_dict'].copy() + for k in checkpoint['state_dict']: + if k.startswith('avg_model.'): + v = state_dict.pop(k) + state_dict[k[4:]] = v + + return state_dict + + +@MODELS.register_module( + Tasks.face_human_hand_detection, + module_name=Models.face_human_hand_detection) +class NanoDetForFaceHumanHandDetection(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + + self.model = OneStageDetector() + if torch.cuda.is_available(): + self.device = 'cuda' + logger.info('Use GPU ') + else: + self.device = 'cpu' + logger.info('Use CPU') + + self.state_dict = load_model_weight(model_dir, self.device) + self.model.load_state_dict(self.state_dict, strict=False) + self.model.eval() + self.model.to(self.device) + + def forward(self, x): + pred_result = self.model.inference(x) + return pred_result + + +def naive_collate(batch): + elem = batch[0] + if isinstance(elem, dict): + return {key: naive_collate([d[key] for d in batch]) for key in elem} + else: + return batch + + +def get_resize_matrix(raw_shape, dst_shape): + + r_w, r_h = raw_shape + d_w, d_h = dst_shape + Rs = np.eye(3) + + Rs[0, 0] *= d_w / r_w + Rs[1, 1] *= d_h / r_h + return Rs + + +def color_aug_and_norm(meta, mean, std): + img = meta['img'].astype(np.float32) / 255 + mean = np.array(mean, dtype=np.float32).reshape(1, 1, 3) / 255 + std = np.array(std, dtype=np.float32).reshape(1, 1, 3) / 255 + img = (img - mean) / std + meta['img'] = img + return meta + + +def img_process(meta, mean, std): + raw_img = meta['img'] + height = raw_img.shape[0] + width = raw_img.shape[1] + dst_shape = [320, 320] + M = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) + ResizeM = get_resize_matrix((width, height), dst_shape) + M = ResizeM @ M + img = cv2.warpPerspective(raw_img, M, dsize=tuple(dst_shape)) + meta['img'] = img + meta['warp_matrix'] = M + meta = color_aug_and_norm(meta, mean, std) + return meta + + +def 
overlay_bbox_cv(dets, class_names, score_thresh): + all_box = [] + for label in dets: + for bbox in dets[label]: + score = bbox[-1] + if score > score_thresh: + x0, y0, x1, y1 = [int(i) for i in bbox[:4]] + all_box.append([label, x0, y0, x1, y1, score]) + all_box.sort(key=lambda v: v[5]) + return all_box + + +mean = [103.53, 116.28, 123.675] +std = [57.375, 57.12, 58.395] +class_names = ['person', 'face', 'hand'] + + +def inference(model, device, img_path): + img_info = {'id': 0} + img = cv2.imread(img_path) + height, width = img.shape[:2] + img_info['height'] = height + img_info['width'] = width + meta = dict(img_info=img_info, raw_img=img, img=img) + + meta = img_process(meta, mean, std) + meta['img'] = torch.from_numpy(meta['img'].transpose(2, 0, 1)).to(device) + meta = naive_collate([meta]) + meta['img'] = (meta['img'][0]).reshape(1, 3, 320, 320) + with torch.no_grad(): + res = model(meta) + result = overlay_bbox_cv(res[0], class_names, score_thresh=0.35) + return result diff --git a/modelscope/models/cv/face_human_hand_detection/ghost_pan.py b/modelscope/models/cv/face_human_hand_detection/ghost_pan.py new file mode 100644 index 00000000..e00de407 --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/ghost_pan.py @@ -0,0 +1,395 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import math + +import torch +import torch.nn as nn + +from .utils import ConvModule, DepthwiseConvModule, act_layers + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0) + else: + return F.relu6(x + 3.0) / 6.0 + + +class SqueezeExcite(nn.Module): + + def __init__(self, + in_chs, + se_ratio=0.25, + reduced_base_chs=None, + activation='ReLU', + gate_fn=hard_sigmoid, + divisor=4, + **_): + super(SqueezeExcite, self).__init__() + self.gate_fn = gate_fn + reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, + divisor) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layers(activation) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class GhostModule(nn.Module): + + def __init__(self, + inp, + oup, + kernel_size=1, + ratio=2, + dw_size=3, + stride=1, + activation='ReLU'): + super(GhostModule, self).__init__() + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels * (ratio - 1) + + self.primary_conv = nn.Sequential( + nn.Conv2d( + inp, + init_channels, + kernel_size, + stride, + kernel_size // 2, + bias=False), + nn.BatchNorm2d(init_channels), + act_layers(activation) if activation else nn.Sequential(), + ) + + self.cheap_operation = nn.Sequential( + nn.Conv2d( + init_channels, + new_channels, + dw_size, + 1, + dw_size // 2, + groups=init_channels, + bias=False, + ), + nn.BatchNorm2d(new_channels), + act_layers(activation) if activation else nn.Sequential(), + ) + + def forward(self, x): + x1 = self.primary_conv(x) + x2 = 
self.cheap_operation(x1) + out = torch.cat([x1, x2], dim=1) + return out + + +class GhostBottleneck(nn.Module): + """Ghost bottleneck w/ optional SE""" + + def __init__( + self, + in_chs, + mid_chs, + out_chs, + dw_kernel_size=3, + stride=1, + activation='ReLU', + se_ratio=0.0, + ): + super(GhostBottleneck, self).__init__() + has_se = se_ratio is not None and se_ratio > 0.0 + self.stride = stride + + # Point-wise expansion + self.ghost1 = GhostModule(in_chs, mid_chs, activation=activation) + + # Depth-wise convolution + if self.stride > 1: + self.conv_dw = nn.Conv2d( + mid_chs, + mid_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size - 1) // 2, + groups=mid_chs, + bias=False, + ) + self.bn_dw = nn.BatchNorm2d(mid_chs) + + if has_se: + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) + else: + self.se = None + + self.ghost2 = GhostModule(mid_chs, out_chs, activation=None) + + if in_chs == out_chs and self.stride == 1: + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_chs, + in_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size - 1) // 2, + groups=in_chs, + bias=False, + ), + nn.BatchNorm2d(in_chs), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_chs), + ) + + def forward(self, x): + residual = x + + x = self.ghost1(x) + + if self.stride > 1: + x = self.conv_dw(x) + x = self.bn_dw(x) + + if self.se is not None: + x = self.se(x) + + x = self.ghost2(x) + + x += self.shortcut(residual) + return x + + +class GhostBlocks(nn.Module): + """Stack of GhostBottleneck used in GhostPAN. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expand (int): Expand ratio of GhostBottleneck. Default: 1. + kernel_size (int): Kernel size of depthwise convolution. Default: 5. + num_blocks (int): Number of GhostBottlecneck blocks. Default: 1. + use_res (bool): Whether to use residual connection. Default: False. + activation (str): Name of activation function. Default: LeakyReLU. + """ + + def __init__( + self, + in_channels, + out_channels, + expand=1, + kernel_size=5, + num_blocks=1, + use_res=False, + activation='LeakyReLU', + ): + super(GhostBlocks, self).__init__() + self.use_res = use_res + if use_res: + self.reduce_conv = ConvModule( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + activation=activation, + ) + blocks = [] + for _ in range(num_blocks): + blocks.append( + GhostBottleneck( + in_channels, + int(out_channels * expand), + out_channels, + dw_kernel_size=kernel_size, + activation=activation, + )) + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + out = self.blocks(x) + if self.use_res: + out = out + self.reduce_conv(x) + return out + + +class GhostPAN(nn.Module): + """Path Aggregation Network with Ghost block. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 3 + use_depthwise (bool): Whether to depthwise separable convolution in + blocks. Default: False + kernel_size (int): Kernel size of depthwise convolution. Default: 5. + expand (int): Expand ratio of GhostBottleneck. Default: 1. + num_blocks (int): Number of GhostBottlecneck blocks. Default: 1. + use_res (bool): Whether to use residual connection. Default: False. + num_extra_level (int): Number of extra conv layers for more feature levels. + Default: 0. 
+ upsample_cfg (dict): Config dict for interpolate layer. + Default: `dict(scale_factor=2, mode='nearest')` + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN') + activation (str): Activation layer name. + Default: LeakyReLU. + """ + + def __init__( + self, + in_channels, + out_channels, + use_depthwise=False, + kernel_size=5, + expand=1, + num_blocks=1, + use_res=False, + num_extra_level=0, + upsample_cfg=dict(scale_factor=2, mode='bilinear'), + norm_cfg=dict(type='BN'), + activation='LeakyReLU', + ): + super(GhostPAN, self).__init__() + assert num_extra_level >= 0 + assert num_blocks >= 1 + self.in_channels = in_channels + self.out_channels = out_channels + + conv = DepthwiseConvModule if use_depthwise else ConvModule + + # build top-down blocks + self.upsample = nn.Upsample(**upsample_cfg) + self.reduce_layers = nn.ModuleList() + for idx in range(len(in_channels)): + self.reduce_layers.append( + ConvModule( + in_channels[idx], + out_channels, + 1, + norm_cfg=norm_cfg, + activation=activation, + )) + self.top_down_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1, 0, -1): + self.top_down_blocks.append( + GhostBlocks( + out_channels * 2, + out_channels, + expand, + kernel_size=kernel_size, + num_blocks=num_blocks, + use_res=use_res, + activation=activation, + )) + + # build bottom-up blocks + self.downsamples = nn.ModuleList() + self.bottom_up_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1): + self.downsamples.append( + conv( + out_channels, + out_channels, + kernel_size, + stride=2, + padding=kernel_size // 2, + norm_cfg=norm_cfg, + activation=activation, + )) + self.bottom_up_blocks.append( + GhostBlocks( + out_channels * 2, + out_channels, + expand, + kernel_size=kernel_size, + num_blocks=num_blocks, + use_res=use_res, + activation=activation, + )) + + # extra layers + self.extra_lvl_in_conv = nn.ModuleList() + self.extra_lvl_out_conv = nn.ModuleList() + for i in range(num_extra_level): + self.extra_lvl_in_conv.append( + conv( + out_channels, + out_channels, + kernel_size, + stride=2, + padding=kernel_size // 2, + norm_cfg=norm_cfg, + activation=activation, + )) + self.extra_lvl_out_conv.append( + conv( + out_channels, + out_channels, + kernel_size, + stride=2, + padding=kernel_size // 2, + norm_cfg=norm_cfg, + activation=activation, + )) + + def forward(self, inputs): + """ + Args: + inputs (tuple[Tensor]): input features. + Returns: + tuple[Tensor]: multi level features. 
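+        Example (illustrative, using the config from OneStageDetector in this repo):
+            >>> import torch
+            >>> fpn = GhostPAN(in_channels=[116, 232, 464], out_channels=96,
+            ...                use_depthwise=True, num_extra_level=1)
+            >>> feats = (torch.rand(1, 116, 40, 40), torch.rand(1, 232, 20, 20),
+            ...          torch.rand(1, 464, 10, 10))
+            >>> [tuple(f.shape[1:]) for f in fpn(feats)]
+            [(96, 40, 40), (96, 20, 20), (96, 10, 10), (96, 5, 5)]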
+ """ + assert len(inputs) == len(self.in_channels) + inputs = [ + reduce(input_x) + for input_x, reduce in zip(inputs, self.reduce_layers) + ] + # top-down path + inner_outs = [inputs[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = inputs[idx - 1] + + inner_outs[0] = feat_heigh + + upsample_feat = self.upsample(feat_heigh) + + inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx]( + torch.cat([upsample_feat, feat_low], 1)) + inner_outs.insert(0, inner_out) + + # bottom-up path + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsamples[idx](feat_low) + out = self.bottom_up_blocks[idx]( + torch.cat([downsample_feat, feat_height], 1)) + outs.append(out) + + # extra layers + for extra_in_layer, extra_out_layer in zip(self.extra_lvl_in_conv, + self.extra_lvl_out_conv): + outs.append(extra_in_layer(inputs[-1]) + extra_out_layer(outs[-1])) + + return tuple(outs) diff --git a/modelscope/models/cv/face_human_hand_detection/nanodet_plus_head.py b/modelscope/models/cv/face_human_hand_detection/nanodet_plus_head.py new file mode 100644 index 00000000..7f5b50ec --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/nanodet_plus_head.py @@ -0,0 +1,427 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import math + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops import nms + +from .utils import ConvModule, DepthwiseConvModule + + +class Integral(nn.Module): + """A fixed layer for calculating integral result from distribution. + This layer calculates the target location by :math: `sum{P(y_i) * y_i}`, + P(y_i) denotes the softmax vector that represents the discrete distribution + y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max} + Args: + reg_max (int): The maximal value of the discrete set. Default: 16. You + may want to reset it according to your new dataset or related + settings. + """ + + def __init__(self, reg_max=16): + super(Integral, self).__init__() + self.reg_max = reg_max + self.register_buffer('project', + torch.linspace(0, self.reg_max, self.reg_max + 1)) + + def forward(self, x): + """Forward feature from the regression head to get integral result of + bounding box location. + Args: + x (Tensor): Features of the regression head, shape (N, 4*(n+1)), + n is self.reg_max. + Returns: + x (Tensor): Integral result of box locations, i.e., distance + offsets from the box center in four directions, shape (N, 4). + """ + shape = x.size() + x = F.softmax(x.reshape(*shape[:-1], 4, self.reg_max + 1), dim=-1) + x = F.linear(x, self.project.type_as(x)).reshape(*shape[:-1], 4) + return x + + +def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): + """Performs non-maximum suppression in a batched fashion. + Modified from https://github.com/pytorch/vision/blob + /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39. + In order to perform NMS independently per class, we add an offset to all + the boxes. The offset is dependent only on the class idx, and is large + enough so that boxes from different classes do not overlap. + Arguments: + boxes (torch.Tensor): boxes in shape (N, 4). + scores (torch.Tensor): scores in shape (N, ). 
+ idxs (torch.Tensor): each index value correspond to a bbox cluster, + and NMS will not be applied between elements of different idxs, + shape (N, ). + nms_cfg (dict): specify nms type and other parameters like iou_thr. + Possible keys includes the following. + - iou_thr (float): IoU threshold used for NMS. + - split_thr (float): threshold number of boxes. In some cases the + number of boxes is large (e.g., 200k). To avoid OOM during + training, the users could set `split_thr` to a small value. + If the number of boxes is greater than the threshold, it will + perform NMS on each group of boxes separately and sequentially. + Defaults to 10000. + class_agnostic (bool): if true, nms is class agnostic, + i.e. IoU thresholding happens over all boxes, + regardless of the predicted class. + Returns: + tuple: kept dets and indice. + """ + nms_cfg_ = nms_cfg.copy() + class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic) + if class_agnostic: + boxes_for_nms = boxes + else: + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + 1) + boxes_for_nms = boxes + offsets[:, None] + nms_cfg_.pop('type', 'nms') + split_thr = nms_cfg_.pop('split_thr', 10000) + if len(boxes_for_nms) < split_thr: + keep = nms(boxes_for_nms, scores, **nms_cfg_) + boxes = boxes[keep] + scores = scores[keep] + else: + total_mask = scores.new_zeros(scores.size(), dtype=torch.bool) + for id in torch.unique(idxs): + mask = (idxs == id).nonzero(as_tuple=False).view(-1) + keep = nms(boxes_for_nms[mask], scores[mask], **nms_cfg_) + total_mask[mask[keep]] = True + + keep = total_mask.nonzero(as_tuple=False).view(-1) + keep = keep[scores[keep].argsort(descending=True)] + boxes = boxes[keep] + scores = scores[keep] + + return torch.cat([boxes, scores[:, None]], -1), keep + + +def multiclass_nms(multi_bboxes, + multi_scores, + score_thr, + nms_cfg, + max_num=-1, + score_factors=None): + """NMS for multi-class bboxes. + + Args: + multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_scores (Tensor): shape (n, #class), where the last column + contains scores of the background class, but this will be ignored. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + nms_thr (float): NMS IoU threshold + max_num (int): if there are more than max_num bboxes after NMS, + only top max_num will be kept. + score_factors (Tensor): The factors multiplied to scores before + applying NMS + + Returns: + tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels \ + are 0-based. 
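+    Example (illustrative):
+        >>> import torch
+        >>> boxes = torch.tensor([[0., 0., 10., 10.]])
+        >>> scores = torch.tensor([[0.9, 0.1]])  # one real class + background column
+        >>> dets, labels = multiclass_nms(boxes, scores, score_thr=0.5,
+        ...                               nms_cfg=dict(type='nms', iou_threshold=0.6))
+        >>> dets.shape, labels.tolist()
+        (torch.Size([1, 5]), [0])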
+ """ + num_classes = multi_scores.size(1) - 1 + if multi_bboxes.shape[1] > 4: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) + else: + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, 4) + scores = multi_scores[:, :-1] + + valid_mask = scores > score_thr + + bboxes = torch.masked_select( + bboxes, + torch.stack((valid_mask, valid_mask, valid_mask, valid_mask), + -1)).view(-1, 4) + if score_factors is not None: + scores = scores * score_factors[:, None] + scores = torch.masked_select(scores, valid_mask) + labels = valid_mask.nonzero(as_tuple=False)[:, 1] + + if bboxes.numel() == 0: + bboxes = multi_bboxes.new_zeros((0, 5)) + labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) + + if torch.onnx.is_in_onnx_export(): + raise RuntimeError('[ONNX Error] Can not record NMS ' + 'as it has not been executed this time') + return bboxes, labels + + dets, keep = batched_nms(bboxes, scores, labels, nms_cfg) + + if max_num > 0: + dets = dets[:max_num] + keep = keep[:max_num] + + return dets, labels[keep] + + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[..., 0] - distance[..., 0] + y1 = points[..., 1] - distance[..., 1] + x2 = points[..., 0] + distance[..., 2] + y2 = points[..., 1] + distance[..., 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return torch.stack([x1, y1, x2, y2], -1) + + +def warp_boxes(boxes, M, width, height): + n = len(boxes) + if n: + xy = np.ones((n * 4, 3)) + xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) + xy = xy @ M.T + xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + xy = np.concatenate( + (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width) + xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height) + return xy.astype(np.float32) + else: + return boxes + + +class NanoDetPlusHead(nn.Module): + """Detection head used in NanoDet-Plus. + + Args: + num_classes (int): Number of categories excluding the background + category. + loss (dict): Loss config. + input_channel (int): Number of channels of the input feature. + feat_channels (int): Number of channels of the feature. + Default: 96. + stacked_convs (int): Number of conv layers in the stacked convs. + Default: 2. + kernel_size (int): Size of the convolving kernel. Default: 5. + strides (list[int]): Strides of input multi-level feature maps. + Default: [8, 16, 32]. + conv_type (str): Type of the convolution. + Default: "DWConv". + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + reg_max (int): The maximal value of the discrete set. Default: 7. + activation (str): Type of activation function. Default: "LeakyReLU". + assigner_cfg (dict): Config dict of the assigner. Default: dict(topk=13). 
+ """ + + def __init__(self, + num_classes, + input_channel, + feat_channels=96, + stacked_convs=2, + kernel_size=5, + strides=[8, 16, 32], + conv_type='DWConv', + norm_cfg=dict(type='BN'), + reg_max=7, + activation='LeakyReLU', + assigner_cfg=dict(topk=13), + **kwargs): + super(NanoDetPlusHead, self).__init__() + self.num_classes = num_classes + self.in_channels = input_channel + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.kernel_size = kernel_size + self.strides = strides + self.reg_max = reg_max + self.activation = activation + self.ConvModule = ConvModule if conv_type == 'Conv' else DepthwiseConvModule + + self.norm_cfg = norm_cfg + self.distribution_project = Integral(self.reg_max) + + self._init_layers() + + def _init_layers(self): + self.cls_convs = nn.ModuleList() + for _ in self.strides: + cls_convs = self._buid_not_shared_head() + self.cls_convs.append(cls_convs) + + self.gfl_cls = nn.ModuleList([ + nn.Conv2d( + self.feat_channels, + self.num_classes + 4 * (self.reg_max + 1), + 1, + padding=0, + ) for _ in self.strides + ]) + + def _buid_not_shared_head(self): + cls_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + cls_convs.append( + self.ConvModule( + chn, + self.feat_channels, + self.kernel_size, + stride=1, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + bias=self.norm_cfg is None, + activation=self.activation, + )) + return cls_convs + + def forward(self, feats): + if torch.onnx.is_in_onnx_export(): + return self._forward_onnx(feats) + outputs = [] + for feat, cls_convs, gfl_cls in zip( + feats, + self.cls_convs, + self.gfl_cls, + ): + for conv in cls_convs: + feat = conv(feat) + output = gfl_cls(feat) + outputs.append(output.flatten(start_dim=2)) + outputs = torch.cat(outputs, dim=2).permute(0, 2, 1) + return outputs + + def post_process(self, preds, meta): + """Prediction results post processing. Decode bboxes and rescale + to original image size. + Args: + preds (Tensor): Prediction output. + meta (dict): Meta info. 
+ """ + cls_scores, bbox_preds = preds.split( + [self.num_classes, 4 * (self.reg_max + 1)], dim=-1) + result_list = self.get_bboxes(cls_scores, bbox_preds, meta) + det_results = {} + warp_matrixes = ( + meta['warp_matrix'] + if isinstance(meta['warp_matrix'], list) else meta['warp_matrix']) + img_heights = ( + meta['img_info']['height'].cpu().numpy() if isinstance( + meta['img_info']['height'], torch.Tensor) else + meta['img_info']['height']) + img_widths = ( + meta['img_info']['width'].cpu().numpy() if isinstance( + meta['img_info']['width'], torch.Tensor) else + meta['img_info']['width']) + img_ids = ( + meta['img_info']['id'].cpu().numpy() if isinstance( + meta['img_info']['id'], torch.Tensor) else + meta['img_info']['id']) + + for result, img_width, img_height, img_id, warp_matrix in zip( + result_list, img_widths, img_heights, img_ids, warp_matrixes): + det_result = {} + det_bboxes, det_labels = result + det_bboxes = det_bboxes.detach().cpu().numpy() + det_bboxes[:, :4] = warp_boxes(det_bboxes[:, :4], + np.linalg.inv(warp_matrix), + img_width, img_height) + classes = det_labels.detach().cpu().numpy() + for i in range(self.num_classes): + inds = classes == i + det_result[i] = np.concatenate( + [ + det_bboxes[inds, :4].astype(np.float32), + det_bboxes[inds, 4:5].astype(np.float32), + ], + axis=1, + ).tolist() + det_results[img_id] = det_result + return det_results + + def get_bboxes(self, cls_preds, reg_preds, img_metas): + """Decode the outputs to bboxes. + Args: + cls_preds (Tensor): Shape (num_imgs, num_points, num_classes). + reg_preds (Tensor): Shape (num_imgs, num_points, 4 * (regmax + 1)). + img_metas (dict): Dict of image info. + + Returns: + results_list (list[tuple]): List of detection bboxes and labels. + """ + device = cls_preds.device + b = cls_preds.shape[0] + input_height, input_width = img_metas['img'].shape[2:] + input_shape = (input_height, input_width) + + featmap_sizes = [(math.ceil(input_height / stride), + math.ceil(input_width) / stride) + for stride in self.strides] + mlvl_center_priors = [ + self.get_single_level_center_priors( + b, + featmap_sizes[i], + stride, + dtype=torch.float32, + device=device, + ) for i, stride in enumerate(self.strides) + ] + center_priors = torch.cat(mlvl_center_priors, dim=1) + dis_preds = self.distribution_project(reg_preds) * center_priors[..., + 2, + None] + bboxes = distance2bbox( + center_priors[..., :2], dis_preds, max_shape=input_shape) + scores = cls_preds.sigmoid() + result_list = [] + for i in range(b): + score, bbox = scores[i], bboxes[i] + padding = score.new_zeros(score.shape[0], 1) + score = torch.cat([score, padding], dim=1) + results = multiclass_nms( + bbox, + score, + score_thr=0.05, + nms_cfg=dict(type='nms', iou_threshold=0.6), + max_num=100, + ) + result_list.append(results) + return result_list + + def get_single_level_center_priors(self, batch_size, featmap_size, stride, + dtype, device): + """Generate centers of a single stage feature map. + Args: + batch_size (int): Number of images in one batch. + featmap_size (tuple[int]): height and width of the feature map + stride (int): down sample stride of the feature map + dtype (obj:`torch.dtype`): data type of the tensors + device (obj:`torch.device`): device of the tensors + Return: + priors (Tensor): center priors of a single level feature map. 
+ """ + h, w = featmap_size + x_range = (torch.arange(w, dtype=dtype, device=device)) * stride + y_range = (torch.arange(h, dtype=dtype, device=device)) * stride + y, x = torch.meshgrid(y_range, x_range) + y = y.flatten() + x = x.flatten() + strides = x.new_full((x.shape[0], ), stride) + proiors = torch.stack([x, y, strides, strides], dim=-1) + return proiors.unsqueeze(0).repeat(batch_size, 1, 1) diff --git a/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py b/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py new file mode 100644 index 00000000..c1d0a52f --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py @@ -0,0 +1,64 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import torch +import torch.nn as nn + +from .ghost_pan import GhostPAN +from .nanodet_plus_head import NanoDetPlusHead +from .shufflenetv2 import ShuffleNetV2 + + +class OneStageDetector(nn.Module): + + def __init__(self): + super(OneStageDetector, self).__init__() + self.backbone = ShuffleNetV2( + model_size='1.0x', + out_stages=(2, 3, 4), + with_last_conv=False, + kernal_size=3, + activation='LeakyReLU', + pretrain=False) + self.fpn = GhostPAN( + in_channels=[116, 232, 464], + out_channels=96, + use_depthwise=True, + kernel_size=5, + expand=1, + num_blocks=1, + use_res=False, + num_extra_level=1, + upsample_cfg=dict(scale_factor=2, mode='bilinear'), + norm_cfg=dict(type='BN'), + activation='LeakyReLU') + self.head = NanoDetPlusHead( + num_classes=3, + input_channel=96, + feat_channels=96, + stacked_convs=2, + kernel_size=5, + strides=[8, 16, 32, 64], + conv_type='DWConv', + norm_cfg=dict(type='BN'), + reg_max=7, + activation='LeakyReLU', + assigner_cfg=dict(topk=13)) + self.epoch = 0 + + def forward(self, x): + x = self.backbone(x) + if hasattr(self, 'fpn'): + x = self.fpn(x) + if hasattr(self, 'head'): + x = self.head(x) + return x + + def inference(self, meta): + with torch.no_grad(): + torch.cuda.synchronize() + preds = self(meta['img']) + torch.cuda.synchronize() + results = self.head.post_process(preds, meta) + torch.cuda.synchronize() + return results diff --git a/modelscope/models/cv/face_human_hand_detection/shufflenetv2.py b/modelscope/models/cv/face_human_hand_detection/shufflenetv2.py new file mode 100644 index 00000000..7f4dfc2a --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/shufflenetv2.py @@ -0,0 +1,182 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import torch +import torch.nn as nn + +from .utils import act_layers + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = num_channels // groups + + x = x.view(batchsize, groups, channels_per_group, height, width) + + x = torch.transpose(x, 1, 2).contiguous() + + x = x.view(batchsize, -1, height, width) + + return x + + +class ShuffleV2Block(nn.Module): + + def __init__(self, inp, oup, stride, activation='ReLU'): + super(ShuffleV2Block, self).__init__() + + if not (1 <= stride <= 3): + raise ValueError('illegal stride value') + self.stride = stride + + branch_features = oup // 2 + assert (self.stride != 1) or (inp == branch_features << 1) + + if self.stride > 1: + self.branch1 = nn.Sequential( + self.depthwise_conv( + inp, inp, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(inp), 
+ nn.Conv2d( + inp, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False), + nn.BatchNorm2d(branch_features), + act_layers(activation), + ) + else: + self.branch1 = nn.Sequential() + + self.branch2 = nn.Sequential( + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + act_layers(activation), + self.depthwise_conv( + branch_features, + branch_features, + kernel_size=3, + stride=self.stride, + padding=1, + ), + nn.BatchNorm2d(branch_features), + nn.Conv2d( + branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + act_layers(activation), + ) + + @staticmethod + def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + return nn.Conv2d( + i, o, kernel_size, stride, padding, bias=bias, groups=i) + + def forward(self, x): + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + else: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + +class ShuffleNetV2(nn.Module): + + def __init__( + self, + model_size='1.5x', + out_stages=(2, 3, 4), + with_last_conv=False, + kernal_size=3, + activation='ReLU', + pretrain=True, + ): + super(ShuffleNetV2, self).__init__() + assert set(out_stages).issubset((2, 3, 4)) + + print('model size is ', model_size) + + self.stage_repeats = [4, 8, 4] + self.model_size = model_size + self.out_stages = out_stages + self.with_last_conv = with_last_conv + self.kernal_size = kernal_size + self.activation = activation + if model_size == '0.5x': + self._stage_out_channels = [24, 48, 96, 192, 1024] + elif model_size == '1.0x': + self._stage_out_channels = [24, 116, 232, 464, 1024] + elif model_size == '1.5x': + self._stage_out_channels = [24, 176, 352, 704, 1024] + elif model_size == '2.0x': + self._stage_out_channels = [24, 244, 488, 976, 2048] + else: + raise NotImplementedError + + # building first layer + input_channels = 3 + output_channels = self._stage_out_channels[0] + self.conv1 = nn.Sequential( + nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), + nn.BatchNorm2d(output_channels), + act_layers(activation), + ) + input_channels = output_channels + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] + for name, repeats, output_channels in zip( + stage_names, self.stage_repeats, self._stage_out_channels[1:]): + seq = [ + ShuffleV2Block( + input_channels, output_channels, 2, activation=activation) + ] + for i in range(repeats - 1): + seq.append( + ShuffleV2Block( + output_channels, + output_channels, + 1, + activation=activation)) + setattr(self, name, nn.Sequential(*seq)) + input_channels = output_channels + output_channels = self._stage_out_channels[-1] + if self.with_last_conv: + conv5 = nn.Sequential( + nn.Conv2d( + input_channels, output_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(output_channels), + act_layers(activation), + ) + self.stage4.add_module('conv5', conv5) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + output = [] + + for i in range(2, 5): + stage = getattr(self, 'stage{}'.format(i)) + x = stage(x) + if i in self.out_stages: + output.append(x) + return tuple(output) diff --git a/modelscope/models/cv/face_human_hand_detection/utils.py b/modelscope/models/cv/face_human_hand_detection/utils.py new file 
mode 100644 index 00000000..f989c164 --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/utils.py @@ -0,0 +1,277 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import torch +import torch.nn as nn + +activations = { + 'ReLU': nn.ReLU, + 'LeakyReLU': nn.LeakyReLU, + 'ReLU6': nn.ReLU6, + 'SELU': nn.SELU, + 'ELU': nn.ELU, + 'GELU': nn.GELU, + 'PReLU': nn.PReLU, + 'SiLU': nn.SiLU, + 'HardSwish': nn.Hardswish, + 'Hardswish': nn.Hardswish, + None: nn.Identity, +} + + +def act_layers(name): + assert name in activations.keys() + if name == 'LeakyReLU': + return nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif name == 'GELU': + return nn.GELU() + elif name == 'PReLU': + return nn.PReLU() + else: + return activations[name](inplace=True) + + +norm_cfg = { + 'BN': ('bn', nn.BatchNorm2d), + 'SyncBN': ('bn', nn.SyncBatchNorm), + 'GN': ('gn', nn.GroupNorm), +} + + +def build_norm_layer(cfg, num_features, postfix=''): + """Build normalization layer + + Args: + cfg (dict): cfg should contain: + type (str): identify norm layer type. + layer args: args needed to instantiate a norm layer. + requires_grad (bool): [optional] whether stop gradient updates + num_features (int): number of channels from input. + postfix (int, str): appended into norm abbreviation to + create named layer. + + Returns: + name (str): abbreviation + postfix + layer (nn.Module): created norm layer + """ + assert isinstance(cfg, dict) and 'type' in cfg + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in norm_cfg: + raise KeyError('Unrecognized norm type {}'.format(layer_type)) + else: + abbr, norm_layer = norm_cfg[layer_type] + if norm_layer is None: + raise NotImplementedError + + assert isinstance(postfix, (int, str)) + name = abbr + str(postfix) + + requires_grad = cfg_.pop('requires_grad', True) + cfg_.setdefault('eps', 1e-5) + if layer_type != 'GN': + layer = norm_layer(num_features, **cfg_) + if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): + layer._specify_ddp_gpu_num(1) + else: + assert 'num_groups' in cfg_ + layer = norm_layer(num_channels=num_features, **cfg_) + + for param in layer.parameters(): + param.requires_grad = requires_grad + + return name, layer + + +class ConvModule(nn.Module): + """A conv block that contains conv/norm/activation layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + conv_cfg (dict): Config dict for convolution layer. + norm_cfg (dict): Config dict for normalization layer. + activation (str): activation layer, "ReLU" by default. + inplace (bool): Whether to use inplace mode for activation. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). 
+ """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias='auto', + conv_cfg=None, + norm_cfg=None, + activation='ReLU', + inplace=True, + order=('conv', 'norm', 'act'), + ): + super(ConvModule, self).__init__() + assert conv_cfg is None or isinstance(conv_cfg, dict) + assert norm_cfg is None or isinstance(norm_cfg, dict) + assert activation is None or isinstance(activation, str) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.activation = activation + self.inplace = inplace + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == {'conv', 'norm', 'act'} + + self.with_norm = norm_cfg is not None + if bias == 'auto': + bias = False if self.with_norm else True + self.with_bias = bias + + if self.with_norm and self.with_bias: + warnings.warn('ConvModule has norm and bias at the same time') + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = self.conv.padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_norm: + if order.index('norm') > order.index('conv'): + norm_channels = out_channels + else: + norm_channels = in_channels + self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels) + self.add_module(self.norm_name, norm) + else: + self.norm_name = None + + if self.activation: + self.act = act_layers(self.activation) + + @property + def norm(self): + if self.norm_name: + return getattr(self, self.norm_name) + else: + return None + + def forward(self, x, norm=True): + for layer in self.order: + if layer == 'conv': + x = self.conv(x) + elif layer == 'norm' and norm and self.with_norm: + x = self.norm(x) + elif layer == 'act' and self.activation: + x = self.act(x) + return x + + +class DepthwiseConvModule(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias='auto', + norm_cfg=dict(type='BN'), + activation='ReLU', + inplace=True, + order=('depthwise', 'dwnorm', 'act', 'pointwise', 'pwnorm', 'act'), + ): + super(DepthwiseConvModule, self).__init__() + assert activation is None or isinstance(activation, str) + self.activation = activation + self.inplace = inplace + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 6 + assert set(order) == { + 'depthwise', + 'dwnorm', + 'act', + 'pointwise', + 'pwnorm', + 'act', + } + + self.with_norm = norm_cfg is not None + if bias == 'auto': + bias = False if self.with_norm else True + self.with_bias = bias + + if self.with_norm and self.with_bias: + warnings.warn('ConvModule has norm and bias at the same time') + + self.depthwise = nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=bias, + ) + self.pointwise = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias) + + self.in_channels = self.depthwise.in_channels + self.out_channels = self.pointwise.out_channels + self.kernel_size = self.depthwise.kernel_size + self.stride = self.depthwise.stride + self.padding = 
self.depthwise.padding + self.dilation = self.depthwise.dilation + self.transposed = self.depthwise.transposed + self.output_padding = self.depthwise.output_padding + + if self.with_norm: + _, self.dwnorm = build_norm_layer(norm_cfg, in_channels) + _, self.pwnorm = build_norm_layer(norm_cfg, out_channels) + + if self.activation: + self.act = act_layers(self.activation) + + def forward(self, x, norm=True): + for layer_name in self.order: + if layer_name != 'act': + layer = self.__getattr__(layer_name) + x = layer(x) + elif layer_name == 'act' and self.activation: + x = self.act(x) + return x diff --git a/modelscope/models/cv/image_colorization/unet.py b/modelscope/models/cv/image_colorization/unet.py index 8123651e..19f6ab62 100644 --- a/modelscope/models/cv/image_colorization/unet.py +++ b/modelscope/models/cv/image_colorization/unet.py @@ -1,3 +1,5 @@ +# The implementation here is modified based on DeOldify, originally MIT License +# and publicly available at https://github.com/jantic/DeOldify/blob/master/deoldify/unet.py import numpy as np import torch import torch.nn as nn diff --git a/modelscope/models/cv/image_colorization/utils.py b/modelscope/models/cv/image_colorization/utils.py index 03473f90..b8968aa0 100644 --- a/modelscope/models/cv/image_colorization/utils.py +++ b/modelscope/models/cv/image_colorization/utils.py @@ -1,3 +1,5 @@ +# The implementation here is modified based on DeOldify, originally MIT License and +# publicly available at https://github.com/jantic/DeOldify/blob/master/fastai/callbacks/hooks.py import functools from enum import Enum diff --git a/modelscope/models/cv/image_portrait_enhancement/align_faces.py b/modelscope/models/cv/image_portrait_enhancement/align_faces.py index 776b06d8..e6852f8c 100755 --- a/modelscope/models/cv/image_portrait_enhancement/align_faces.py +++ b/modelscope/models/cv/image_portrait_enhancement/align_faces.py @@ -1,3 +1,5 @@ +# Part of the implementation is borrowed and modified from Face-Alignment, +# publicly available at https://github.com/foamliu/Face-Alignment/blob/master/align_faces.py import cv2 import numpy as np from skimage import transform as trans diff --git a/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py b/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py index fe4081a4..51f2206e 100755 --- a/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py +++ b/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
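A minimal, self-contained sketch (not part of this patch) of the box decoding that NanoDetPlusHead.get_bboxes performs with the center priors built above; decode_boxes is a hypothetical stand-in that folds the stride scaling and the distance2bbox step into one helper, and the distances are assumed to already be the expected bin offsets produced by Integral.

import torch


def decode_boxes(center_priors, distances):
    # center_priors: (N, 4) rows of [cx, cy, stride, stride]
    # distances: (N, 4) offsets [left, top, right, bottom] in stride units
    dis = distances * center_priors[..., 2, None]  # convert bin offsets to pixels
    x1 = center_priors[..., 0] - dis[..., 0]
    y1 = center_priors[..., 1] - dis[..., 1]
    x2 = center_priors[..., 0] + dis[..., 2]
    y2 = center_priors[..., 1] + dis[..., 3]
    return torch.stack([x1, y1, x2, y2], dim=-1)


stride, h, w = 8, 4, 4  # toy 4x4 feature map at stride 8
ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w))
xs = xs.flatten().float() * stride
ys = ys.flatten().float() * stride
priors = torch.stack(
    [xs, ys, torch.full_like(xs, stride), torch.full_like(xs, stride)], dim=-1)
boxes = decode_boxes(priors, torch.full((h * w, 4), 2.0))  # (16, 4) xyxy boxes, 32 px per side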
import os import cv2 diff --git a/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py b/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py index ea3c4f2a..e0e8e9d5 100644 --- a/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py +++ b/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py @@ -1,3 +1,5 @@ +# The implementation is adopted from FaceQuality, made publicly available under the MIT License +# at https://github.com/deepcam-cn/FaceQuality/blob/master/models/model_resnet.py import torch from torch import nn diff --git a/modelscope/models/cv/image_portrait_enhancement/gpen.py b/modelscope/models/cv/image_portrait_enhancement/gpen.py index 2e21dbc0..86009a41 100755 --- a/modelscope/models/cv/image_portrait_enhancement/gpen.py +++ b/modelscope/models/cv/image_portrait_enhancement/gpen.py @@ -1,3 +1,5 @@ +# The GPEN implementation is also open-sourced by the authors, +# and available at https://github.com/yangxy/GPEN/blob/main/face_model/gpen_model.py import functools import itertools import math diff --git a/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py b/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py index 3250d393..3650ac7b 100644 --- a/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py +++ b/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import math import os.path as osp from copy import deepcopy diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py b/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py index 35ca202f..86f6f227 100644 --- a/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py +++ b/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py @@ -1,3 +1,5 @@ +# The implementation is adopted from InsightFace_Pytorch, +# made publicly available under the MIT License at https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py from collections import namedtuple import torch diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/losses.py b/modelscope/models/cv/image_portrait_enhancement/losses/losses.py index 8934eee7..0f5198c3 100644 --- a/modelscope/models/cv/image_portrait_enhancement/losses/losses.py +++ b/modelscope/models/cv/image_portrait_enhancement/losses/losses.py @@ -1,3 +1,5 @@ +# The GPEN implementation is also open-sourced by the authors, +# and available at https://github.com/yangxy/GPEN/tree/main/training/loss/id_loss.py import torch import torch.nn as nn import torch.nn.functional as F diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py b/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py index 3b87d7fd..00dc7c52 100644 --- a/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py +++ b/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py @@ -1,3 +1,5 @@ +# The implementation is adopted from InsightFace_Pytorch, +# made publicly available under the MIT License at https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, Module, PReLU, Sequential) diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py index c294438a..7ad780a8 100755 --- 
a/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py @@ -1,3 +1,5 @@ +# The GPEN implementation is also open-sourced by the authors, +# and available at https://github.com/yangxy/GPEN/blob/main/face_detect/retinaface_detection.py import os import cv2 diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py index 0546e0bb..24451e96 100755 --- a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py @@ -1,3 +1,5 @@ +# The implementation is adopted from Pytorch_Retinaface, made publicly available under the MIT License +# at https://github.com/biubug6/Pytorch_Retinaface/tree/master/models/net.py import time import torch diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py index af1d706d..64d95971 100755 --- a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py @@ -1,3 +1,5 @@ +# The implementation is adopted from Pytorch_Retinaface, made publicly available under the MIT License +# at https://github.com/biubug6/Pytorch_Retinaface/tree/master/models/retinaface.py from collections import OrderedDict import torch diff --git a/modelscope/models/cv/image_to_image_generation/model.py b/modelscope/models/cv/image_to_image_generation/model.py index 37479b43..94e5dd7b 100644 --- a/modelscope/models/cv/image_to_image_generation/model.py +++ b/modelscope/models/cv/image_to_image_generation/model.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math import torch diff --git a/modelscope/models/cv/image_to_image_generation/models/autoencoder.py b/modelscope/models/cv/image_to_image_generation/models/autoencoder.py index 181472de..dce256f6 100644 --- a/modelscope/models/cv/image_to_image_generation/models/autoencoder.py +++ b/modelscope/models/cv/image_to_image_generation/models/autoencoder.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math import torch diff --git a/modelscope/models/cv/image_to_image_generation/models/clip.py b/modelscope/models/cv/image_to_image_generation/models/clip.py index 35d9d882..d3dd22b4 100644 --- a/modelscope/models/cv/image_to_image_generation/models/clip.py +++ b/modelscope/models/cv/image_to_image_generation/models/clip.py @@ -1,3 +1,5 @@ +# Part of the implementation is borrowed and modified from CLIP, publicly available at https://github.com/openai/CLIP. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math import torch diff --git a/modelscope/models/cv/image_to_image_generation/ops/diffusion.py b/modelscope/models/cv/image_to_image_generation/ops/diffusion.py index bcbb6402..b8ffbbbb 100644 --- a/modelscope/models/cv/image_to_image_generation/ops/diffusion.py +++ b/modelscope/models/cv/image_to_image_generation/ops/diffusion.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import math import torch diff --git a/modelscope/models/cv/image_to_image_generation/ops/losses.py b/modelscope/models/cv/image_to_image_generation/ops/losses.py index 23e8d246..46b9540a 100644 --- a/modelscope/models/cv/image_to_image_generation/ops/losses.py +++ b/modelscope/models/cv/image_to_image_generation/ops/losses.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math import torch diff --git a/modelscope/models/cv/image_to_image_translation/data/transforms.py b/modelscope/models/cv/image_to_image_translation/data/transforms.py index 5376d813..29a25b4b 100644 --- a/modelscope/models/cv/image_to_image_translation/data/transforms.py +++ b/modelscope/models/cv/image_to_image_translation/data/transforms.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math import random diff --git a/modelscope/models/cv/image_to_image_translation/model_translation.py b/modelscope/models/cv/image_to_image_translation/model_translation.py index 722b175d..f2a9e7db 100644 --- a/modelscope/models/cv/image_to_image_translation/model_translation.py +++ b/modelscope/models/cv/image_to_image_translation/model_translation.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math import torch diff --git a/modelscope/models/cv/image_to_image_translation/models/autoencoder.py b/modelscope/models/cv/image_to_image_translation/models/autoencoder.py index 181472de..dce256f6 100644 --- a/modelscope/models/cv/image_to_image_translation/models/autoencoder.py +++ b/modelscope/models/cv/image_to_image_translation/models/autoencoder.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math import torch diff --git a/modelscope/models/cv/image_to_image_translation/models/clip.py b/modelscope/models/cv/image_to_image_translation/models/clip.py index 35d9d882..d3dd22b4 100644 --- a/modelscope/models/cv/image_to_image_translation/models/clip.py +++ b/modelscope/models/cv/image_to_image_translation/models/clip.py @@ -1,3 +1,5 @@ +# Part of the implementation is borrowed and modified from CLIP, publicly avaialbe at https://github.com/openai/CLIP. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math import torch diff --git a/modelscope/models/cv/image_to_image_translation/ops/apps.py b/modelscope/models/cv/image_to_image_translation/ops/apps.py index ee4be489..39d2e015 100644 --- a/modelscope/models/cv/image_to_image_translation/ops/apps.py +++ b/modelscope/models/cv/image_to_image_translation/ops/apps.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. # APPs that facilitate the use of pretrained neural networks. import os.path as osp diff --git a/modelscope/models/cv/image_to_image_translation/ops/degradation.py b/modelscope/models/cv/image_to_image_translation/ops/degradation.py index c3b3d1df..9061e7be 100644 --- a/modelscope/models/cv/image_to_image_translation/ops/degradation.py +++ b/modelscope/models/cv/image_to_image_translation/ops/degradation.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
import math import os import random diff --git a/modelscope/models/cv/image_to_image_translation/ops/diffusion.py b/modelscope/models/cv/image_to_image_translation/ops/diffusion.py index bcbb6402..5ff37dc3 100644 --- a/modelscope/models/cv/image_to_image_translation/ops/diffusion.py +++ b/modelscope/models/cv/image_to_image_translation/ops/diffusion.py @@ -1,3 +1,6 @@ +# Part of the implementation is borrowed and modified from latent-diffusion, +# publicly avaialbe at https://github.com/CompVis/latent-diffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math import torch diff --git a/modelscope/models/cv/image_to_image_translation/ops/losses.py b/modelscope/models/cv/image_to_image_translation/ops/losses.py index 23e8d246..46b9540a 100644 --- a/modelscope/models/cv/image_to_image_translation/ops/losses.py +++ b/modelscope/models/cv/image_to_image_translation/ops/losses.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math import torch diff --git a/modelscope/models/cv/image_to_image_translation/ops/metrics.py b/modelscope/models/cv/image_to_image_translation/ops/metrics.py index 4a63c51f..c1023fa0 100644 --- a/modelscope/models/cv/image_to_image_translation/ops/metrics.py +++ b/modelscope/models/cv/image_to_image_translation/ops/metrics.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import numpy as np import scipy.linalg as linalg import torch diff --git a/modelscope/models/cv/image_to_image_translation/ops/random_color.py b/modelscope/models/cv/image_to_image_translation/ops/random_color.py index 97e2f848..75692836 100644 --- a/modelscope/models/cv/image_to_image_translation/ops/random_color.py +++ b/modelscope/models/cv/image_to_image_translation/ops/random_color.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import colorsys import random diff --git a/modelscope/models/cv/image_to_image_translation/ops/random_mask.py b/modelscope/models/cv/image_to_image_translation/ops/random_mask.py index a6b55916..bda1ec11 100644 --- a/modelscope/models/cv/image_to_image_translation/ops/random_mask.py +++ b/modelscope/models/cv/image_to_image_translation/ops/random_mask.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. import cv2 import numpy as np diff --git a/modelscope/models/cv/image_to_image_translation/ops/svd.py b/modelscope/models/cv/image_to_image_translation/ops/svd.py index c5173de1..96f7e825 100644 --- a/modelscope/models/cv/image_to_image_translation/ops/svd.py +++ b/modelscope/models/cv/image_to_image_translation/ops/svd.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. r"""SVD of linear degradation matrices described in the paper ``Denoising Diffusion Restoration Models.'' @article{kawar2022denoising, diff --git a/modelscope/models/cv/image_to_image_translation/ops/utils.py b/modelscope/models/cv/image_to_image_translation/ops/utils.py index 3e523f4c..c2aacedc 100644 --- a/modelscope/models/cv/image_to_image_translation/ops/utils.py +++ b/modelscope/models/cv/image_to_image_translation/ops/utils.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
import base64 import binascii import hashlib diff --git a/modelscope/models/cv/movie_scene_segmentation/model.py b/modelscope/models/cv/movie_scene_segmentation/model.py index 1232d427..8117961a 100644 --- a/modelscope/models/cv/movie_scene_segmentation/model.py +++ b/modelscope/models/cv/movie_scene_segmentation/model.py @@ -67,7 +67,6 @@ class MovieSceneSegmentationModel(TorchModel): mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) - self.infer_result = {'vid': [], 'sid': [], 'pred': []} sampling_method = self.cfg.dataset.sampling_method.name self.neighbor_size = self.cfg.dataset.sampling_method.params[ sampling_method].neighbor_size @@ -104,6 +103,8 @@ class MovieSceneSegmentationModel(TorchModel): shot_num = len(sids) cnt = shot_num // bs + 1 + infer_sid, infer_pred = [], [] + infer_result = {} for i in range(cnt): start = i * bs end = (i + 1) * bs if (i + 1) * bs < shot_num else shot_num @@ -112,13 +113,14 @@ class MovieSceneSegmentationModel(TorchModel): input_ = torch.stack(input_) outputs = self.shared_step(input_) # shape [b,2] prob = F.softmax(outputs, dim=1) - self.infer_result['sid'].extend(sid_.cpu().detach().numpy()) - self.infer_result['pred'].extend(prob[:, 1].cpu().detach().numpy()) - self.infer_result['pred'] = np.stack(self.infer_result['pred']) + infer_sid.extend(sid_.cpu().detach().numpy()) + infer_pred.extend(prob[:, 1].cpu().detach().numpy()) + infer_result.update({'pred': np.stack(infer_pred)}) + infer_result.update({'sid': infer_sid}) - assert len(self.infer_result['sid']) == len(sids) - assert len(self.infer_result['pred']) == len(inputs) - return self.infer_result + assert len(infer_result['sid']) == len(sids) + assert len(infer_result['pred']) == len(inputs) + return infer_result def shared_step(self, inputs): with torch.no_grad(): @@ -162,11 +164,12 @@ class MovieSceneSegmentationModel(TorchModel): thres = self.cfg.pipeline.save_threshold anno_dict = get_pred_boundary(pred_dict, thres) - scene_dict_lst, scene_list = pred2scene(self.shot2keyf, anno_dict) + scene_dict_lst, scene_list, shot_num, shot_dict_lst = pred2scene( + self.shot2keyf, anno_dict) if self.cfg.pipeline.save_split_scene: re_dir = scene2video(inputs['input_video_pth'], scene_list, thres) print(f'Split scene video saved to {re_dir}') - return len(scene_list), scene_dict_lst + return len(scene_list), scene_dict_lst, shot_num, shot_dict_lst def preprocess(self, inputs): logger.info('Begin shot detect......') diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py b/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py index b350ff13..3339e1a3 100644 --- a/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py +++ b/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py @@ -22,15 +22,23 @@ def pred2scene(shot2keyf, anno_dict): scene_list, pair_list = get_demo_scene_list(shot2keyf, anno_dict) scene_dict_lst = [] + shot_num = len(shot2keyf) + shot_dict_lst = [] + for item in shot2keyf: + tmp = item.split(' ') + shot_dict_lst.append({ + 'frame': [tmp[0], tmp[1]], + 'timestamps': [tmp[-2], tmp[-1]] + }) assert len(scene_list) == len(pair_list) for scene_ind, scene_item in enumerate(scene_list): scene_dict_lst.append({ 'shot': pair_list[scene_ind], 'frame': scene_item[0], - 'timestamp': scene_item[1] + 'timestamps': scene_item[1] }) - return scene_dict_lst, scene_list + return scene_dict_lst, scene_list, shot_num, shot_dict_lst def scene2video(source_movie_fn, scene_list, thres): diff --git a/modelscope/models/cv/product_segmentation/__init__.py 
b/modelscope/models/cv/product_segmentation/__init__.py new file mode 100644 index 00000000..e87c8db1 --- /dev/null +++ b/modelscope/models/cv/product_segmentation/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .seg_infer import F3NetProductSegmentation + +else: + _import_structure = {'seg_infer': ['F3NetProductSegmentation']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/product_segmentation/net.py b/modelscope/models/cv/product_segmentation/net.py new file mode 100644 index 00000000..454c99d8 --- /dev/null +++ b/modelscope/models/cv/product_segmentation/net.py @@ -0,0 +1,197 @@ +# The implementation here is modified based on F3Net, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/weijun88/F3Net + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Bottleneck(nn.Module): + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=(3 * dilation - 1) // 2, + bias=False, + dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.downsample = downsample + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x)), inplace=True) + out = F.relu(self.bn2(self.conv2(out)), inplace=True) + out = self.bn3(self.conv3(out)) + if self.downsample is not None: + x = self.downsample(x) + return F.relu(out + x, inplace=True) + + +class ResNet(nn.Module): + + def __init__(self): + super(ResNet, self).__init__() + self.inplanes = 64 + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self.make_layer(64, 3, stride=1, dilation=1) + self.layer2 = self.make_layer(128, 4, stride=2, dilation=1) + self.layer3 = self.make_layer(256, 6, stride=2, dilation=1) + self.layer4 = self.make_layer(512, 3, stride=2, dilation=1) + + def make_layer(self, planes, blocks, stride, dilation): + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * 4, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(planes * 4)) + layers = [ + Bottleneck( + self.inplanes, planes, stride, downsample, dilation=dilation) + ] + self.inplanes = planes * 4 + for _ in range(1, blocks): + layers.append(Bottleneck(self.inplanes, planes, dilation=dilation)) + return nn.Sequential(*layers) + + def forward(self, x): + x = x.reshape(1, 3, 448, 448) + out1 = F.relu(self.bn1(self.conv1(x)), inplace=True) + out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1) + out2 = self.layer1(out1) + out3 = self.layer2(out2) + out4 = self.layer3(out3) + out5 = self.layer4(out4) + return out2, out3, out4, out5 + + +class CFM(nn.Module): + + def __init__(self): + super(CFM, self).__init__() + self.conv1h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn1h = nn.BatchNorm2d(64) + self.conv2h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn2h = 
nn.BatchNorm2d(64) + self.conv3h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn3h = nn.BatchNorm2d(64) + self.conv4h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn4h = nn.BatchNorm2d(64) + + self.conv1v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn1v = nn.BatchNorm2d(64) + self.conv2v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn2v = nn.BatchNorm2d(64) + self.conv3v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn3v = nn.BatchNorm2d(64) + self.conv4v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn4v = nn.BatchNorm2d(64) + + def forward(self, left, down): + if down.size()[2:] != left.size()[2:]: + down = F.interpolate(down, size=left.size()[2:], mode='bilinear') + out1h = F.relu(self.bn1h(self.conv1h(left)), inplace=True) + out2h = F.relu(self.bn2h(self.conv2h(out1h)), inplace=True) + out1v = F.relu(self.bn1v(self.conv1v(down)), inplace=True) + out2v = F.relu(self.bn2v(self.conv2v(out1v)), inplace=True) + fuse = out2h * out2v + out3h = F.relu(self.bn3h(self.conv3h(fuse)), inplace=True) + out1h + out4h = F.relu(self.bn4h(self.conv4h(out3h)), inplace=True) + out3v = F.relu(self.bn3v(self.conv3v(fuse)), inplace=True) + out1v + out4v = F.relu(self.bn4v(self.conv4v(out3v)), inplace=True) + return out4h, out4v + + +class Decoder(nn.Module): + + def __init__(self): + super(Decoder, self).__init__() + self.cfm45 = CFM() + self.cfm34 = CFM() + self.cfm23 = CFM() + + def forward(self, out2h, out3h, out4h, out5v, fback=None): + if fback is not None: + refine5 = F.interpolate( + fback, size=out5v.size()[2:], mode='bilinear') + refine4 = F.interpolate( + fback, size=out4h.size()[2:], mode='bilinear') + refine3 = F.interpolate( + fback, size=out3h.size()[2:], mode='bilinear') + refine2 = F.interpolate( + fback, size=out2h.size()[2:], mode='bilinear') + out5v = out5v + refine5 + out4h, out4v = self.cfm45(out4h + refine4, out5v) + out3h, out3v = self.cfm34(out3h + refine3, out4v) + out2h, pred = self.cfm23(out2h + refine2, out3v) + else: + out4h, out4v = self.cfm45(out4h, out5v) + out3h, out3v = self.cfm34(out3h, out4v) + out2h, pred = self.cfm23(out2h, out3v) + return out2h, out3h, out4h, out5v, pred + + +class F3Net(nn.Module): + + def __init__(self): + super(F3Net, self).__init__() + self.bkbone = ResNet() + self.squeeze5 = nn.Sequential( + nn.Conv2d(2048, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + self.squeeze4 = nn.Sequential( + nn.Conv2d(1024, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + self.squeeze3 = nn.Sequential( + nn.Conv2d(512, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + self.squeeze2 = nn.Sequential( + nn.Conv2d(256, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + + self.decoder1 = Decoder() + self.decoder2 = Decoder() + self.linearp1 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearp2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + + self.linearr2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearr3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearr4 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearr5 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + + def forward(self, x, shape=None): + x = x.reshape(1, 3, 448, 448) + out2h, out3h, out4h, out5v = self.bkbone(x) + out2h, out3h, out4h, out5v = self.squeeze2(out2h), self.squeeze3( + out3h), self.squeeze4(out4h), self.squeeze5(out5v) + out2h, out3h, out4h, out5v, pred1 = self.decoder1( + out2h, out3h, out4h, 
out5v) + out2h, out3h, out4h, out5v, pred2 = self.decoder2( + out2h, out3h, out4h, out5v, pred1) + + shape = x.size()[2:] if shape is None else shape + pred1 = F.interpolate( + self.linearp1(pred1), size=shape, mode='bilinear') + pred2 = F.interpolate( + self.linearp2(pred2), size=shape, mode='bilinear') + + out2h = F.interpolate( + self.linearr2(out2h), size=shape, mode='bilinear') + out3h = F.interpolate( + self.linearr3(out3h), size=shape, mode='bilinear') + out4h = F.interpolate( + self.linearr4(out4h), size=shape, mode='bilinear') + out5h = F.interpolate( + self.linearr5(out5v), size=shape, mode='bilinear') + return pred1, pred2, out2h, out3h, out4h, out5h diff --git a/modelscope/models/cv/product_segmentation/seg_infer.py b/modelscope/models/cv/product_segmentation/seg_infer.py new file mode 100644 index 00000000..876fac66 --- /dev/null +++ b/modelscope/models/cv/product_segmentation/seg_infer.py @@ -0,0 +1,77 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import cv2 +import numpy as np +import torch +from PIL import Image + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .net import F3Net + +logger = get_logger() + + +def load_state_dict(model_dir, device): + _dict = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=device) + state_dict = {} + for k, v in _dict.items(): + if k.startswith('module'): + k = k[7:] + state_dict[k] = v + return state_dict + + +@MODELS.register_module( + Tasks.product_segmentation, module_name=Models.product_segmentation) +class F3NetForProductSegmentation(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + + self.model = F3Net() + if torch.cuda.is_available(): + self.device = 'cuda' + logger.info('Use GPU') + else: + self.device = 'cpu' + logger.info('Use CPU') + + self.params = load_state_dict(model_dir, self.device) + self.model.load_state_dict(self.params) + self.model.to(self.device) + self.model.eval() + self.model.to(self.device) + + def forward(self, x): + pred_result = self.model(x) + return pred_result + + +mean, std = np.array([[[124.55, 118.90, + 102.94]]]), np.array([[[56.77, 55.97, 57.50]]]) + + +def inference(model, device, input_path): + img = Image.open(input_path) + img = np.array(img.convert('RGB')).astype(np.float32) + img = (img - mean) / std + img = cv2.resize(img, dsize=(448, 448), interpolation=cv2.INTER_LINEAR) + img = torch.from_numpy(img) + img = img.permute(2, 0, 1) + img = img.to(device).float() + outputs = model(img) + out = outputs[0] + pred = (torch.sigmoid(out[0, 0]) * 255).cpu().numpy() + pred[pred < 20] = 0 + pred = pred[:, :, np.newaxis] + pred = np.round(pred) + logger.info('Inference Done') + return pred diff --git a/modelscope/models/cv/skin_retouching/detection_model/detection_module.py b/modelscope/models/cv/skin_retouching/detection_model/detection_module.py index f89ce37b..5db9c44c 100644 --- a/modelscope/models/cv/skin_retouching/detection_model/detection_module.py +++ b/modelscope/models/cv/skin_retouching/detection_model/detection_module.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
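A short usage sketch for the product-segmentation helpers defined in seg_infer.py above; the model directory is a placeholder and is assumed to contain the serialized weights file the loader expects, while the test image path is the one added under data/test/images in this patch.

import cv2

from modelscope.models.cv.product_segmentation.seg_infer import (
    F3NetForProductSegmentation, inference)

model = F3NetForProductSegmentation('/path/to/product_segmentation_model_dir')  # placeholder snapshot dir
mask = inference(model, model.device, 'data/test/images/product_segmentation.jpg')
cv2.imwrite('product_mask.png', mask)  # 448x448x1 mask with values in [0, 255]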
import torch import torch.nn as nn diff --git a/modelscope/models/cv/skin_retouching/detection_model/detection_unet_in.py b/modelscope/models/cv/skin_retouching/detection_model/detection_unet_in.py index b48f6e5f..c0be1a52 100644 --- a/modelscope/models/cv/skin_retouching/detection_model/detection_unet_in.py +++ b/modelscope/models/cv/skin_retouching/detection_model/detection_unet_in.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import torch import torch.nn as nn import torch.nn.functional as F diff --git a/modelscope/models/cv/skin_retouching/inpainting_model/gconv.py b/modelscope/models/cv/skin_retouching/inpainting_model/gconv.py index e0910d2c..8b3eb2fc 100644 --- a/modelscope/models/cv/skin_retouching/inpainting_model/gconv.py +++ b/modelscope/models/cv/skin_retouching/inpainting_model/gconv.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import torch import torch.nn as nn diff --git a/modelscope/models/cv/skin_retouching/inpainting_model/inpainting_unet.py b/modelscope/models/cv/skin_retouching/inpainting_model/inpainting_unet.py index 09cea1fc..dd220dd6 100644 --- a/modelscope/models/cv/skin_retouching/inpainting_model/inpainting_unet.py +++ b/modelscope/models/cv/skin_retouching/inpainting_model/inpainting_unet.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import torch import torch.nn as nn import torch.nn.functional as F diff --git a/modelscope/models/cv/skin_retouching/unet_deploy.py b/modelscope/models/cv/skin_retouching/unet_deploy.py index cb37b04c..0ff75b85 100755 --- a/modelscope/models/cv/skin_retouching/unet_deploy.py +++ b/modelscope/models/cv/skin_retouching/unet_deploy.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import warnings import torch diff --git a/modelscope/models/cv/skin_retouching/utils.py b/modelscope/models/cv/skin_retouching/utils.py index 12653f41..eb0da6b9 100644 --- a/modelscope/models/cv/skin_retouching/utils.py +++ b/modelscope/models/cv/skin_retouching/utils.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import time from typing import Dict, List, Optional, Tuple, Union diff --git a/modelscope/models/cv/skin_retouching/weights_init.py b/modelscope/models/cv/skin_retouching/weights_init.py index efd24843..ae62d4a4 100644 --- a/modelscope/models/cv/skin_retouching/weights_init.py +++ b/modelscope/models/cv/skin_retouching/weights_init.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
import torch import torch.nn as nn diff --git a/modelscope/models/cv/super_resolution/arch_util.py b/modelscope/models/cv/super_resolution/arch_util.py index 4b87c877..99711a11 100644 --- a/modelscope/models/cv/super_resolution/arch_util.py +++ b/modelscope/models/cv/super_resolution/arch_util.py @@ -1,3 +1,5 @@ +# The implementation is adopted from BasicSR, made public available under the Apache 2.0 License +# at https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/archs/arch_util.py import collections.abc import math import warnings diff --git a/modelscope/models/cv/super_resolution/rrdbnet_arch.py b/modelscope/models/cv/super_resolution/rrdbnet_arch.py index 44947de1..8c84f796 100644 --- a/modelscope/models/cv/super_resolution/rrdbnet_arch.py +++ b/modelscope/models/cv/super_resolution/rrdbnet_arch.py @@ -1,3 +1,5 @@ +# The implementation is adopted from BasicSR, made public available under the Apache 2.0 License +# at https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/archs/rrdbnet_arch.py import torch from torch import nn as nn from torch.nn import functional as F diff --git a/modelscope/models/multi_modal/clip/__init__.py b/modelscope/models/multi_modal/clip/__init__.py index 3fd492b9..e2e925ce 100644 --- a/modelscope/models/multi_modal/clip/__init__.py +++ b/modelscope/models/multi_modal/clip/__init__.py @@ -1 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + from .model import CLIPForMultiModalEmbedding diff --git a/modelscope/models/multi_modal/clip/model.py b/modelscope/models/multi_modal/clip/model.py index 2fb0d7e3..92d9e11a 100644 --- a/modelscope/models/multi_modal/clip/model.py +++ b/modelscope/models/multi_modal/clip/model.py @@ -1,3 +1,18 @@ +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import os from collections import OrderedDict from typing import Any, Dict, Iterable, List, Tuple, Union diff --git a/modelscope/models/multi_modal/gemm/gemm_base.py b/modelscope/models/multi_modal/gemm/gemm_base.py index 09ef2480..806c469c 100644 --- a/modelscope/models/multi_modal/gemm/gemm_base.py +++ b/modelscope/models/multi_modal/gemm/gemm_base.py @@ -543,6 +543,7 @@ class GEMMModel(nn.Module): img_feature, text_feature, caption = None, None, None if captioning and image is not None: img_feature, caption = self.model.image_to_text(image) + img_feature = self.parse_feat(img_feature) elif image is not None: img_feature = self.parse_feat(self.model.encode_image(image)) if text is not None: diff --git a/modelscope/models/multi_modal/gemm/gemm_model.py b/modelscope/models/multi_modal/gemm/gemm_model.py index 55b211c0..c90b35d4 100644 --- a/modelscope/models/multi_modal/gemm/gemm_model.py +++ b/modelscope/models/multi_modal/gemm/gemm_model.py @@ -67,7 +67,7 @@ class GEMMForMultiModalEmbedding(TorchModel): return img_tensor def parse_text(self, text_str): - if text_str is None: + if text_str is None or len(text_str) == 0: return None if isinstance(text_str, str): text_ids_tensor = self.gemm_model.tokenize(text_str) @@ -79,9 +79,12 @@ class GEMMForMultiModalEmbedding(TorchModel): return text_ids_tensor.view(1, -1) def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - image = self.parse_image(input.get('image', input.get('img', None))) - text = self.parse_text(input.get('text', input.get('txt', None))) - captioning = input.get('captioning', False) is True + image_input = input.get('image', input.get('img', None)) + text_input = input.get('text', input.get('txt', None)) + captioning_input = input.get('captioning', None) + image = self.parse_image(image_input) + text = self.parse_text(text_input) + captioning = captioning_input is True or text_input == '' out = self.gemm_model(image, text, captioning) output = { OutputKeys.IMG_EMBEDDING: out.get('image_feature', None), diff --git a/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py index 5e8e2e7a..2a72985f 100644 --- a/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py +++ b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py @@ -1,4 +1,4 @@ -# The implementation is adopated from the CLIP4Clip implementation, +# The implementation is adopted from the CLIP4Clip implementation, # made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip import random diff --git a/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py b/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py index 253a847c..c2d96275 100644 --- a/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py +++ b/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py @@ -1,4 +1,4 @@ -# The implementation is adopated from the CLIP4Clip implementation, +# The implementation is adopted from the CLIP4Clip implementation, # made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip import numpy as np diff --git a/modelscope/models/multi_modal/mmr/models/tokenization_clip.py b/modelscope/models/multi_modal/mmr/models/tokenization_clip.py index 4e2c9b15..97ee7156 100644 --- a/modelscope/models/multi_modal/mmr/models/tokenization_clip.py +++ b/modelscope/models/multi_modal/mmr/models/tokenization_clip.py @@ -1,4 +1,4 @@ 
-# The implementation is adopated from the CLIP4Clip implementation, +# The implementation is adopted from the CLIP4Clip implementation, # made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip import gzip diff --git a/modelscope/models/multi_modal/ofa/__init__.py b/modelscope/models/multi_modal/ofa/__init__.py index 16de7fff..3e8e59f4 100644 --- a/modelscope/models/multi_modal/ofa/__init__.py +++ b/modelscope/models/multi_modal/ofa/__init__.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel from .tokenization_ofa import OFATokenizer, OFATokenizerZH from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast diff --git a/modelscope/models/multi_modal/ofa/resnet.py b/modelscope/models/multi_modal/ofa/resnet.py index de6444ab..aad0f002 100644 --- a/modelscope/models/multi_modal/ofa/resnet.py +++ b/modelscope/models/multi_modal/ofa/resnet.py @@ -1,3 +1,17 @@ +# Copyright 2022 OFA-Sys Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import torch.nn as nn diff --git a/modelscope/models/multi_modal/ofa/utils/__init__.py b/modelscope/models/multi_modal/ofa/utils/__init__.py index f515818c..76b03eeb 100644 --- a/modelscope/models/multi_modal/ofa/utils/__init__.py +++ b/modelscope/models/multi_modal/ofa/utils/__init__.py @@ -1 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from .constant import OFA_TASK_KEY_MAPPING diff --git a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py index b942e3fa..8110a0f7 100644 --- a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py +++ b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 357afd07..717ff4dd 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -21,6 +21,7 @@ class OutputKeys(object): POLYGONS = 'polygons' OUTPUT = 'output' OUTPUT_IMG = 'output_img' + OUTPUT_VIDEO = 'output_video' OUTPUT_PCM = 'output_pcm' IMG_EMBEDDING = 'img_embedding' SPO_LIST = 'spo_list' @@ -37,8 +38,10 @@ class OutputKeys(object): KWS_LIST = 'kws_list' HISTORY = 'history' TIMESTAMPS = 'timestamps' - SPLIT_VIDEO_NUM = 'split_video_num' - SPLIT_META_LIST = 'split_meta_list' + SHOT_NUM = 'shot_num' + SCENE_NUM = 'scene_num' + SCENE_META_LIST = 'scene_meta_list' + SHOT_META_LIST = 'shot_meta_list' TASK_OUTPUTS = { @@ -218,13 +221,21 @@ TASK_OUTPUTS = { # 3D human body keypoints detection result for single sample # { - # "poses": [ - # [[x, y, z]*17], - # [[x, y, z]*17], - # [[x, y, z]*17] - # ] + # "poses": [ # 3d pose coordinate in camera coordinate + # [[x, y, z]*17], # joints of per image + # [[x, y, z]*17], + # ... 
+ # ], + # "timestamps": [ # timestamps of all frames + # "00:00:0.230", + # "00:00:0.560", + # "00:00:0.690", + # ], + # "output_video": "path_to_rendered_video" , this is optional + # and is only avaialbe when the "render" option is enabled. # } - Tasks.body_3d_keypoints: [OutputKeys.POSES], + Tasks.body_3d_keypoints: + [OutputKeys.POSES, OutputKeys.TIMESTAMPS, OutputKeys.OUTPUT_VIDEO], # 2D hand keypoints result for single sample # { @@ -300,19 +311,30 @@ TASK_OUTPUTS = { Tasks.shop_segmentation: [OutputKeys.MASKS], # movide scene segmentation result for a single video # { - # "split_video_num":3, - # "split_meta_list": + # "shot_num":15, + # "shot_meta_list": + # [ + # { + # "frame": [start_frame, end_frame], + # "timestamps": [start_timestamp, end_timestamp] # ['00:00:01.133', '00:00:02.245'] + # + # } + # ] + # "scene_num":3, + # "scene_meta_list": # [ # { # "shot": [0,1,2], # "frame": [start_frame, end_frame], - # "timestamp": [start_timestamp, end_timestamp] # ['00:00:01.133', '00:00:02.245'] + # "timestamps": [start_timestamp, end_timestamp] # ['00:00:01.133', '00:00:02.245'] # } # ] # # } - Tasks.movie_scene_segmentation: - [OutputKeys.SPLIT_VIDEO_NUM, OutputKeys.SPLIT_META_LIST], + Tasks.movie_scene_segmentation: [ + OutputKeys.SHOT_NUM, OutputKeys.SHOT_META_LIST, OutputKeys.SCENE_NUM, + OutputKeys.SCENE_META_LIST + ], # ============ nlp tasks =================== @@ -649,8 +671,28 @@ TASK_OUTPUTS = { # 'output': ['Done' / 'Decode_Error'] # } Tasks.video_inpainting: [OutputKeys.OUTPUT], + # { # 'output': ['bixin'] # } - Tasks.hand_static: [OutputKeys.OUTPUT] + Tasks.hand_static: [OutputKeys.OUTPUT], + + # 'output': [ + # [2, 75, 287, 240, 510, 0.8335018754005432], + # [1, 127, 83, 332, 366, 0.9175254702568054], + # [0, 0, 0, 367, 639, 0.9693422317504883]] + # } + Tasks.face_human_hand_detection: [OutputKeys.OUTPUT], + + # { + # {'output': 'Happiness', 'boxes': (203, 104, 663, 564)} + # } + Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES], + + # { + # "masks": [ + # np.array # 2D array containing only 0, 255 + # ] + # } + Tasks.product_segmentation: [OutputKeys.MASKS], } diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py new file mode 100644 index 00000000..de9814a7 --- /dev/null +++ b/modelscope/pipeline_inputs.py @@ -0,0 +1,236 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
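A brief sketch of how a caller would read the renamed movie-scene-segmentation outputs documented above, assuming a pipeline built for Tasks.movie_scene_segmentation returns the four keys listed in TASK_OUTPUTS; the model id and video path below are placeholders.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

scene_seg = pipeline(
    Tasks.movie_scene_segmentation,
    model='path/or/id-of-movie-scene-segmentation-model')  # placeholder
result = scene_seg('path/to/movie.mp4')  # placeholder video
print(result[OutputKeys.SCENE_NUM], result[OutputKeys.SHOT_NUM])
for scene in result[OutputKeys.SCENE_META_LIST]:
    print(scene['shot'], scene['frame'], scene['timestamps'])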
+ +import cv2 +import numpy as np +from PIL import Image + +from modelscope.models.base.base_head import Input +from modelscope.utils.constant import Tasks + + +class InputKeys(object): + IMAGE = 'image' + TEXT = 'text' + VIDEO = 'video' + + +class InputType(object): + IMAGE = 'image' + TEXT = 'text' + AUDIO = 'audio' + VIDEO = 'video' + BOX = 'box' + DICT = 'dict' + LIST = 'list' + INT = 'int' + + +INPUT_TYPE = { + InputType.IMAGE: (str, np.ndarray, Image.Image), + InputType.TEXT: str, + InputType.AUDIO: (str, np.ndarray), + InputType.VIDEO: (str, np.ndarray, cv2.VideoCapture), + InputType.BOX: (list, np.ndarray), + InputType.DICT: (dict, type(None)), + InputType.LIST: (list, type(None)), + InputType.INT: int, +} + + +def check_input_type(input_type, input): + expected_type = INPUT_TYPE[input_type] + assert isinstance(input, expected_type), \ + f'invalid input type for {input_type}, expected {expected_type} but got {type(input)}\n {input}' + + +TASK_INPUTS = { + # if task input is single var, value is InputType + # if task input is a tuple, value is tuple of InputType + # if task input is a dict, value is a dict of InputType, where key + # equals the one needed in pipeline input dict + # if task input is a list, value is a set of input format, in which + # each elements corresponds to one input format as described above. + # ============ vision tasks =================== + Tasks.ocr_detection: + InputType.IMAGE, + Tasks.ocr_recognition: + InputType.IMAGE, + Tasks.face_2d_keypoints: + InputType.IMAGE, + Tasks.face_detection: + InputType.IMAGE, + Tasks.facial_expression_recognition: + InputType.IMAGE, + Tasks.face_recognition: + InputType.IMAGE, + Tasks.human_detection: + InputType.IMAGE, + Tasks.face_image_generation: + InputType.INT, + Tasks.image_classification: + InputType.IMAGE, + Tasks.image_object_detection: + InputType.IMAGE, + Tasks.image_segmentation: + InputType.IMAGE, + Tasks.portrait_matting: + InputType.IMAGE, + + # image editing task result for a single image + Tasks.skin_retouching: + InputType.IMAGE, + Tasks.image_super_resolution: + InputType.IMAGE, + Tasks.image_colorization: + InputType.IMAGE, + Tasks.image_color_enhancement: + InputType.IMAGE, + Tasks.image_denoising: + InputType.IMAGE, + Tasks.image_portrait_enhancement: + InputType.IMAGE, + Tasks.crowd_counting: + InputType.IMAGE, + + # image generation task result for a single image + Tasks.image_to_image_generation: + InputType.IMAGE, + Tasks.image_to_image_translation: + InputType.IMAGE, + Tasks.image_style_transfer: + InputType.IMAGE, + Tasks.image_portrait_stylization: + InputType.IMAGE, + Tasks.live_category: + InputType.VIDEO, + Tasks.action_recognition: + InputType.VIDEO, + Tasks.body_2d_keypoints: + InputType.IMAGE, + Tasks.body_3d_keypoints: + InputType.VIDEO, + Tasks.hand_2d_keypoints: + InputType.IMAGE, + Tasks.video_single_object_tracking: (InputType.VIDEO, InputType.BOX), + Tasks.video_category: + InputType.VIDEO, + Tasks.product_retrieval_embedding: + InputType.IMAGE, + Tasks.video_embedding: + InputType.VIDEO, + Tasks.virtual_try_on: (InputType.IMAGE, InputType.IMAGE, InputType.IMAGE), + Tasks.text_driven_segmentation: { + InputKeys.IMAGE: InputType.IMAGE, + InputKeys.TEXT: InputType.TEXT + }, + Tasks.shop_segmentation: + InputType.IMAGE, + Tasks.movie_scene_segmentation: + InputType.VIDEO, + + # ============ nlp tasks =================== + Tasks.text_classification: [ + InputType.TEXT, + (InputType.TEXT, InputType.TEXT), + { + 'text': InputType.TEXT, + 'text2': InputType.TEXT + }, + ], + 
Tasks.sentence_similarity: (InputType.TEXT, InputType.TEXT), + Tasks.nli: (InputType.TEXT, InputType.TEXT), + Tasks.sentiment_classification: + InputType.TEXT, + Tasks.zero_shot_classification: + InputType.TEXT, + Tasks.relation_extraction: + InputType.TEXT, + Tasks.translation: + InputType.TEXT, + Tasks.word_segmentation: + InputType.TEXT, + Tasks.part_of_speech: + InputType.TEXT, + Tasks.named_entity_recognition: + InputType.TEXT, + Tasks.text_error_correction: + InputType.TEXT, + Tasks.sentence_embedding: { + 'source_sentence': InputType.LIST, + 'sentences_to_compare': InputType.LIST, + }, + Tasks.passage_ranking: (InputType.TEXT, InputType.TEXT), + Tasks.text_generation: + InputType.TEXT, + Tasks.fill_mask: + InputType.TEXT, + Tasks.task_oriented_conversation: { + 'user_input': InputType.TEXT, + 'history': InputType.DICT, + }, + Tasks.table_question_answering: { + 'question': InputType.TEXT, + 'history_sql': InputType.DICT, + }, + Tasks.faq_question_answering: { + 'query_set': InputType.LIST, + 'support_set': InputType.LIST, + }, + + # ============ audio tasks =================== + Tasks.auto_speech_recognition: + InputType.AUDIO, + Tasks.speech_signal_process: + InputType.AUDIO, + Tasks.acoustic_echo_cancellation: { + 'nearend_mic': InputType.AUDIO, + 'farend_speech': InputType.AUDIO + }, + Tasks.acoustic_noise_suppression: + InputType.AUDIO, + Tasks.text_to_speech: + InputType.TEXT, + Tasks.keyword_spotting: + InputType.AUDIO, + + # ============ multi-modal tasks =================== + Tasks.image_captioning: + InputType.IMAGE, + Tasks.visual_grounding: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.text_to_image_synthesis: { + 'text': InputType.TEXT, + }, + Tasks.multi_modal_embedding: { + 'img': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.generative_multi_modal_embedding: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.multi_modal_similarity: { + 'img': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.visual_question_answering: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.visual_entailment: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT, + 'text2': InputType.TEXT, + }, + Tasks.action_detection: + InputType.VIDEO, + Tasks.image_reid_person: + InputType.IMAGE, + Tasks.video_inpainting: { + 'video_input_path': InputType.TEXT, + 'video_output_path': InputType.TEXT, + 'mask_path': InputType.TEXT, + } +} diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py index 282d1184..4e8b658d 100644 --- a/modelscope/pipelines/audio/asr_inference_pipeline.py +++ b/modelscope/pipelines/audio/asr_inference_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict, List, Sequence, Tuple, Union import yaml diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index 866b8d0b..450a12bb 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
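The TASK_INPUTS table above is what the new Pipeline._check_input hook (see base.py below) consults before preprocessing. A minimal sketch of how the different declaration shapes are validated, with made-up file paths and texts:

from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type
from modelscope.utils.constant import Tasks

# single-variable declaration: the raw input must match the declared type
check_input_type(TASK_INPUTS[Tasks.image_classification], 'path/to/some_image.jpg')

# dict-style declaration: every declared key that is present gets checked on its own
vqa_input = {'image': 'path/to/some_image.jpg', 'text': 'what is in the picture?'}
for key, expected in TASK_INPUTS[Tasks.visual_question_answering].items():
    if key in vqa_input:
        check_input_type(expected, vqa_input[key])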
import os from typing import Any, Dict, List, Union diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index c5db2b57..5732a9d7 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -13,6 +13,7 @@ import numpy as np from modelscope.models.base import Model from modelscope.msdatasets import MsDataset from modelscope.outputs import TASK_OUTPUTS +from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type from modelscope.preprocessors import Preprocessor from modelscope.utils.config import Config from modelscope.utils.constant import Frameworks, ModelFile @@ -210,7 +211,7 @@ class Pipeline(ABC): preprocess_params = kwargs.get('preprocess_params', {}) forward_params = kwargs.get('forward_params', {}) postprocess_params = kwargs.get('postprocess_params', {}) - + self._check_input(input) out = self.preprocess(input, **preprocess_params) with device_placement(self.framework, self.device_name): if self.framework == Frameworks.torch: @@ -225,6 +226,42 @@ class Pipeline(ABC): self._check_output(out) return out + def _check_input(self, input): + task_name = self.group_key + if task_name in TASK_INPUTS: + input_type = TASK_INPUTS[task_name] + + # if multiple input formats are defined, we first + # found the one that match input data and check + if isinstance(input_type, list): + matched_type = None + for t in input_type: + if type(t) == type(input): + matched_type = t + break + if matched_type is None: + err_msg = 'input data format for current pipeline should be one of following: \n' + for t in input_type: + err_msg += f'{t}\n' + raise ValueError(err_msg) + else: + input_type = matched_type + + if isinstance(input_type, str): + check_input_type(input_type, input) + elif isinstance(input_type, tuple): + for t, input_ele in zip(input_type, input): + check_input_type(t, input_ele) + elif isinstance(input_type, dict): + for k in input_type.keys(): + # allow single input for multi-modal models + if k in input: + check_input_type(input_type[k], input[k]) + else: + raise ValueError(f'invalid input_type definition {input_type}') + else: + logger.warning(f'task {task_name} input definition is missing') + def _check_output(self, input): # this attribute is dynamically attached by registry # when cls is registered in registry using task name diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 4f6873b0..7fa66b5f 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -183,6 +183,12 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_video-inpainting'), Tasks.hand_static: (Pipelines.hand_static, 'damo/cv_mobileface_hand-static'), + Tasks.face_human_hand_detection: + (Pipelines.face_human_hand_detection, + 'damo/cv_nanodet_face-human-hand-detection'), + Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'), + Tasks.product_segmentation: (Pipelines.product_segmentation, + 'damo/cv_F3Net_product-segmentation'), } diff --git a/modelscope/pipelines/cv/action_recognition_pipeline.py b/modelscope/pipelines/cv/action_recognition_pipeline.py index 7f1a46b2..993a32f0 100644 --- a/modelscope/pipelines/cv/action_recognition_pipeline.py +++ b/modelscope/pipelines/cv/action_recognition_pipeline.py @@ -7,7 +7,8 @@ from typing import Any, Dict import torch from modelscope.metainfo import Pipelines -from modelscope.models.cv.action_recognition import BaseVideoModel +from modelscope.models.cv.action_recognition import (BaseVideoModel, + PatchShiftTransformer) from modelscope.outputs import OutputKeys from 
modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES @@ -69,3 +70,54 @@ class ActionRecognitionPipeline(Pipeline): def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs + + +@PIPELINES.register_module( + Tasks.action_recognition, module_name=Pipelines.pst_action_recognition) +class PSTActionRecognitionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a PST action recognition pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + self.infer_model = PatchShiftTransformer(model).to(self.device) + self.infer_model.eval() + self.infer_model.load_state_dict( + torch.load(model_path, map_location=self.device)['state_dict']) + self.label_mapping = self.cfg.label_mapping + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + video_input_data = ReadVideoData(self.cfg, input).to(self.device) + else: + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + result = {'video_data': video_input_data} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + pred = self.perform_inference(input['video_data']) + output_label = self.label_mapping[str(pred)] + return {OutputKeys.LABELS: output_label} + + @torch.no_grad() + def perform_inference(self, data, max_bsz=4): + iter_num = math.ceil(data.size(0) / max_bsz) + preds_list = [] + for i in range(iter_num): + preds_list.append( + self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])) + pred = torch.cat(preds_list, dim=0) + return pred.mean(dim=0).argmax().item() + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py index c6a05195..d6afbae4 100644 --- a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os.path as osp from typing import Any, Dict, List, Union diff --git a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py index e9e4e9e8..b0faa1e0 100644 --- a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py @@ -1,10 +1,19 @@ -import os +# Copyright (c) Alibaba, Inc. and its affiliates. 
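A usage sketch for the PST action-recognition pipeline registered above, mirroring the test added in tests/pipelines/test_action_recognition.py later in this diff; the printed label depends on the model:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# the pipeline batches frames internally (max_bsz=4) and averages the predictions
pst = pipeline(
    Tasks.action_recognition, model='damo/cv_pathshift_action-recognition')
result = pst('data/test/videos/action_recognition_test_video.mp4')
print(result)  # a dict keyed by OutputKeys.LABELS with the predicted action label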
+ +import datetime import os.path as osp +import tempfile from typing import Any, Dict, List, Union import cv2 +import matplotlib +import matplotlib.pyplot as plt +import mpl_toolkits.mplot3d.axes3d as p3 import numpy as np import torch +from matplotlib import animation +from matplotlib.animation import writers +from matplotlib.ticker import MultipleLocator from modelscope.metainfo import Pipelines from modelscope.models.cv.body_3d_keypoints.body_3d_pose import ( @@ -16,6 +25,8 @@ from modelscope.pipelines.builder import PIPELINES from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger +matplotlib.use('Agg') + logger = get_logger() @@ -121,7 +132,8 @@ class Body3DKeypointsPipeline(Pipeline): device='gpu' if torch.cuda.is_available() else 'cpu') def preprocess(self, input: Input) -> Dict[str, Any]: - video_frames = self.read_video_frames(input) + video_url = input + video_frames = self.read_video_frames(video_url) if 0 == len(video_frames): res = {'success': False, 'msg': 'get video frame failed.'} return res @@ -168,13 +180,25 @@ class Body3DKeypointsPipeline(Pipeline): return res def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: - res = {OutputKeys.POSES: []} + res = {OutputKeys.POSES: [], OutputKeys.TIMESTAMPS: []} if not input['success']: pass else: poses = input[KeypointsTypes.POSES_CAMERA] - res = {OutputKeys.POSES: poses.data.cpu().numpy()} + pred_3d_pose = poses.data.cpu().numpy()[ + 0] # [frame_num, joint_num, joint_dim] + + output_video_path = kwargs.get('output_video', None) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile( + suffix='.mp4').name + if 'render' in self.keypoint_model_3d.cfg.keys(): + self.render_prediction(pred_3d_pose, output_video_path) + res[OutputKeys.OUTPUT_VIDEO] = output_video_path + + res[OutputKeys.POSES] = pred_3d_pose + res[OutputKeys.TIMESTAMPS] = self.timestamps return res def read_video_frames(self, video_url: Union[str, cv2.VideoCapture]): @@ -189,7 +213,15 @@ class Body3DKeypointsPipeline(Pipeline): Returns: [nd.array]: List of video frames. """ + + def timestamp_format(seconds): + m, s = divmod(seconds, 60) + h, m = divmod(m, 60) + time = '%02d:%02d:%06.3f' % (h, m, s) + return time + frames = [] + self.timestamps = [] # for video render if isinstance(video_url, str): cap = cv2.VideoCapture(video_url) if not cap.isOpened(): @@ -199,15 +231,131 @@ class Body3DKeypointsPipeline(Pipeline): else: cap = video_url + self.fps = cap.get(cv2.CAP_PROP_FPS) + if self.fps is None or self.fps <= 0: + raise Exception('modelscope error: %s cannot get video fps info.' % + (video_url)) + max_frame_num = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME frame_idx = 0 while True: ret, frame = cap.read() if not ret: break + self.timestamps.append( + timestamp_format(seconds=frame_idx / self.fps)) frame_idx += 1 frames.append(frame) if frame_idx >= max_frame_num: break cap.release() return frames + + def render_prediction(self, pose3d_cam_rr, output_video_path): + """render predict result 3d poses. 
+ + Args: + pose3d_cam_rr (nd.array): [frame_num, joint_num, joint_dim], 3d pose joints + output_video_path (str): output path for video + Returns: + """ + frame_num = pose3d_cam_rr.shape[0] + + left_points = [11, 12, 13, 4, 5, 6] # joints of left body + edges = [[0, 1], [0, 4], [0, 7], [1, 2], [4, 5], [5, 6], [2, + 3], [7, 8], + [8, 9], [8, 11], [8, 14], [14, 15], [15, 16], [11, 12], + [12, 13], [9, 10]] # connection between joints + + fig = plt.figure() + ax = p3.Axes3D(fig) + x_major_locator = MultipleLocator(0.5) + + ax.xaxis.set_major_locator(x_major_locator) + ax.yaxis.set_major_locator(x_major_locator) + ax.zaxis.set_major_locator(x_major_locator) + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.set_xlim(-1, 1) + ax.set_ylim(-1, 1) + ax.set_zlim(-1, 1) + # view direction + azim = self.keypoint_model_3d.cfg.render.azim + elev = self.keypoint_model_3d.cfg.render.elev + ax.view_init(elev, azim) + + # init plot, essentially + x = pose3d_cam_rr[0, :, 0] + y = pose3d_cam_rr[0, :, 1] + z = pose3d_cam_rr[0, :, 2] + points, = ax.plot(x, y, z, 'r.') + + def renderBones(xs, ys, zs): + """render bones in skeleton + + Args: + xs (nd.array): [joint_num, joint_channel] + ys (nd.array): [joint_num, joint_channel] + zs (nd.array): [joint_num, joint_channel] + """ + bones = {} + for idx, edge in enumerate(edges): + index1, index2 = edge[0], edge[1] + if index1 in left_points: + edge_color = 'red' + else: + edge_color = 'blue' + connect = ax.plot([xs[index1], xs[index2]], + [ys[index1], ys[index2]], + [zs[index1], zs[index2]], + linewidth=2, + color=edge_color) # plot edge + bones[idx] = connect[0] + return bones + + bones = renderBones(x, y, z) + + def update(frame_idx, points, bones): + """update animation + + Args: + frame_idx (int): frame index + points (mpl_toolkits.mplot3d.art3d.Line3D): skeleton points plotter + bones (dict[int, mpl_toolkits.mplot3d.art3d.Line3D]): connection plotter + + Returns: + tuple: points and bones plotter + """ + xs = pose3d_cam_rr[frame_idx, :, 0] + ys = pose3d_cam_rr[frame_idx, :, 1] + zs = pose3d_cam_rr[frame_idx, :, 2] + + # update bones + for idx, edge in enumerate(edges): + index1, index2 = edge[0], edge[1] + x1x2 = (xs[index1], xs[index2]) + y1y2 = (ys[index1], ys[index2]) + z1z2 = (zs[index1], zs[index2]) + bones[idx].set_xdata(x1x2) + bones[idx].set_ydata(y1y2) + bones[idx].set_3d_properties(z1z2, 'z') + + # update joints + points.set_data(xs, ys) + points.set_3d_properties(zs, 'z') + if frame_idx % 100 == 0: + logger.info(f'rendering {frame_idx}/{frame_num}') + return points, bones + + ani = animation.FuncAnimation( + fig=fig, + func=update, + frames=frame_num, + interval=self.fps, + fargs=(points, bones)) + + # save mp4 + Writer = writers['ffmpeg'] + writer = Writer(fps=self.fps, metadata={}, bitrate=4096) + ani.save(output_video_path, writer=writer) diff --git a/modelscope/pipelines/cv/crowd_counting_pipeline.py b/modelscope/pipelines/cv/crowd_counting_pipeline.py index 3143825b..93fffdf2 100644 --- a/modelscope/pipelines/cv/crowd_counting_pipeline.py +++ b/modelscope/pipelines/cv/crowd_counting_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import math from typing import Any, Dict diff --git a/modelscope/pipelines/cv/face_emotion_pipeline.py b/modelscope/pipelines/cv/face_emotion_pipeline.py new file mode 100644 index 00000000..249493b6 --- /dev/null +++ b/modelscope/pipelines/cv/face_emotion_pipeline.py @@ -0,0 +1,39 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
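The render_prediction method above is a standard matplotlib FuncAnimation recipe: build the 3D axis once, mutate the artists per frame, then save through the ffmpeg writer. A stripped-down, self-contained sketch of the same idea, using random joints and a hypothetical output path and assuming ffmpeg is installed:

import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend, as in the pipeline
import matplotlib.pyplot as plt
from matplotlib import animation

joints = np.random.uniform(-1, 1, size=(30, 17, 3))  # fake [frame, joint, xyz] data

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
points, = ax.plot(joints[0, :, 0], joints[0, :, 1], joints[0, :, 2], 'r.')

def update(i):
    # move the existing artist instead of re-plotting every frame
    points.set_data(joints[i, :, 0], joints[i, :, 1])
    points.set_3d_properties(joints[i, :, 2], 'z')
    return points,

ani = animation.FuncAnimation(fig, update, frames=len(joints), interval=40)
ani.save('skeleton_preview.mp4', writer=animation.writers['ffmpeg'](fps=25))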
+from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_emotion import emotion_infer +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_emotion, module_name=Pipelines.face_emotion) +class FaceEmotionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create face emotion pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + super().__init__(model=model, **kwargs) + self.face_model = model + '/' + ModelFile.TF_GRAPH_FILE + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result, bbox = emotion_infer.inference(input['img_path'], self.model, + self.face_model) + return {OutputKeys.OUTPUT: result, OutputKeys.BOXES: bbox} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py b/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py new file mode 100644 index 00000000..d9f214c9 --- /dev/null +++ b/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py @@ -0,0 +1,42 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_human_hand_detection import det_infer +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_human_hand_detection, + module_name=Pipelines.face_human_hand_detection) +class NanoDettForFaceHumanHandDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create face-human-hand detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + super().__init__(model=model, **kwargs) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + result = det_infer.inference(self.model, self.device, + input['input_path']) + logger.info(result) + return {OutputKeys.OUTPUT: result} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/face_image_generation_pipeline.py b/modelscope/pipelines/cv/face_image_generation_pipeline.py index 405c9a4b..f00d639e 100644 --- a/modelscope/pipelines/cv/face_image_generation_pipeline.py +++ b/modelscope/pipelines/cv/face_image_generation_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
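A hedged usage sketch for the two new face pipelines above; the input dict keys come from their forward methods, the image paths are illustrative, and reading the first element of each detection row as a face/human/hand category id is an assumption:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

emotion = pipeline(Tasks.face_emotion, model='damo/cv_face-emotion')
detector = pipeline(
    Tasks.face_human_hand_detection,
    model='damo/cv_nanodet_face-human-hand-detection')

emo = emotion({'img_path': 'path/to/face.jpg'})
print(emo[OutputKeys.OUTPUT], emo[OutputKeys.BOXES])

rows = detector({'input_path': 'path/to/group_photo.jpg'})[OutputKeys.OUTPUT]
for category_id, x1, y1, x2, y2, score in rows:  # assumed row layout
    print(category_id, (x1, y1, x2, y2), score)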
import os from typing import Any, Dict diff --git a/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py b/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py index c5577dcf..1b1f13d1 100644 --- a/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py +++ b/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py b/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py index 5e4cd4c6..21af2f75 100644 --- a/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py +++ b/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import math import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py index 787aa06d..8606915c 100644 --- a/modelscope/pipelines/cv/image_cartoon_pipeline.py +++ b/modelscope/pipelines/cv/image_cartoon_pipeline.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict diff --git a/modelscope/pipelines/cv/image_color_enhance_pipeline.py b/modelscope/pipelines/cv/image_color_enhance_pipeline.py index 40777d60..d21d879c 100644 --- a/modelscope/pipelines/cv/image_color_enhance_pipeline.py +++ b/modelscope/pipelines/cv/image_color_enhance_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict, Optional, Union import torch diff --git a/modelscope/pipelines/cv/image_colorization_pipeline.py b/modelscope/pipelines/cv/image_colorization_pipeline.py index 0fea729d..cd385024 100644 --- a/modelscope/pipelines/cv/image_colorization_pipeline.py +++ b/modelscope/pipelines/cv/image_colorization_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict import cv2 diff --git a/modelscope/pipelines/cv/image_denoise_pipeline.py b/modelscope/pipelines/cv/image_denoise_pipeline.py index 64aa3bc9..a11abf36 100644 --- a/modelscope/pipelines/cv/image_denoise_pipeline.py +++ b/modelscope/pipelines/cv/image_denoise_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict, Optional, Union import torch diff --git a/modelscope/pipelines/cv/image_detection_pipeline.py b/modelscope/pipelines/cv/image_detection_pipeline.py index 8df10d45..f5554ca2 100644 --- a/modelscope/pipelines/cv/image_detection_pipeline.py +++ b/modelscope/pipelines/cv/image_detection_pipeline.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + from typing import Any, Dict import numpy as np diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index d7b7fc3c..fb5d8f8b 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py b/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py index 87e692e8..3eec6526 100644 --- a/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py +++ b/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import math from typing import Any, Dict diff --git a/modelscope/pipelines/cv/image_salient_detection_pipeline.py b/modelscope/pipelines/cv/image_salient_detection_pipeline.py index 3b145cf0..4a3eaa65 100644 --- a/modelscope/pipelines/cv/image_salient_detection_pipeline.py +++ b/modelscope/pipelines/cv/image_salient_detection_pipeline.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + from typing import Any, Dict from modelscope.metainfo import Pipelines diff --git a/modelscope/pipelines/cv/image_style_transfer_pipeline.py b/modelscope/pipelines/cv/image_style_transfer_pipeline.py index 64e67115..e5fd0d48 100644 --- a/modelscope/pipelines/cv/image_style_transfer_pipeline.py +++ b/modelscope/pipelines/cv/image_style_transfer_pipeline.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/image_super_resolution_pipeline.py b/modelscope/pipelines/cv/image_super_resolution_pipeline.py index 657acc41..ca8f3209 100644 --- a/modelscope/pipelines/cv/image_super_resolution_pipeline.py +++ b/modelscope/pipelines/cv/image_super_resolution_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict import cv2 diff --git a/modelscope/pipelines/cv/image_to_image_generate_pipeline.py b/modelscope/pipelines/cv/image_to_image_generate_pipeline.py index 2a3881e7..4f0121dd 100644 --- a/modelscope/pipelines/cv/image_to_image_generate_pipeline.py +++ b/modelscope/pipelines/cv/image_to_image_generate_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/image_to_image_translation_pipeline.py b/modelscope/pipelines/cv/image_to_image_translation_pipeline.py index 78901c9b..e5f853ca 100644 --- a/modelscope/pipelines/cv/image_to_image_translation_pipeline.py +++ b/modelscope/pipelines/cv/image_to_image_translation_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
import io import os.path as osp import sys diff --git a/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py b/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py index 6704e4c0..3fffc546 100644 --- a/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py +++ b/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py @@ -60,9 +60,12 @@ class MovieSceneSegmentationPipeline(Pipeline): def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: data = {'input_video_pth': self.input_video_pth, 'feat': inputs} - video_num, meta_lst = self.model.postprocess(data) + scene_num, scene_meta_lst, shot_num, shot_meta_lst = self.model.postprocess( + data) result = { - OutputKeys.SPLIT_VIDEO_NUM: video_num, - OutputKeys.SPLIT_META_LIST: meta_lst + OutputKeys.SHOT_NUM: shot_num, + OutputKeys.SHOT_META_LIST: shot_meta_lst, + OutputKeys.SCENE_NUM: scene_num, + OutputKeys.SCENE_META_LIST: scene_meta_lst } return result diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py index b73f65a4..292ec2c5 100644 --- a/modelscope/pipelines/cv/ocr_detection_pipeline.py +++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp from typing import Any, Dict @@ -56,68 +57,72 @@ class OCRDetectionPipeline(Pipeline): model_path = osp.join( osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER), 'checkpoint-80000') - - with device_placement(self.framework, self.device_name): - config = tf.ConfigProto(allow_soft_placement=True) - config.gpu_options.allow_growth = True - self._session = tf.Session(config=config) - self.input_images = tf.placeholder( - tf.float32, shape=[1, 1024, 1024, 3], name='input_images') - self.output = {} - - with tf.variable_scope('', reuse=tf.AUTO_REUSE): - global_step = tf.get_variable( - 'global_step', [], - initializer=tf.constant_initializer(0), - dtype=tf.int64, - trainable=False) - variable_averages = tf.train.ExponentialMovingAverage( - 0.997, global_step) - - # detector - detector = SegLinkDetector() - all_maps = detector.build_model( - self.input_images, is_training=False) - - # decode local predictions - all_nodes, all_links, all_reg = [], [], [] - for i, maps in enumerate(all_maps): - cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2] - reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE) - - cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2])) - - lnk_prob_pos = tf.nn.softmax( - tf.reshape(lnk_maps, [-1, 4])[:, :2]) - lnk_prob_mut = tf.nn.softmax( - tf.reshape(lnk_maps, [-1, 4])[:, 2:]) - lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1) - - all_nodes.append(cls_prob) - all_links.append(lnk_prob) - all_reg.append(reg_maps) - - # decode segments and links - image_size = tf.shape(self.input_images)[1:3] - segments, group_indices, segment_counts, _ = decode_segments_links_python( - image_size, - all_nodes, - all_links, - all_reg, - anchor_sizes=list(detector.anchor_sizes)) - - # combine segments - combined_rboxes, combined_counts = combine_segments_python( - segments, group_indices, segment_counts) - self.output['combined_rboxes'] = combined_rboxes - self.output['combined_counts'] = combined_counts - - with self._session.as_default() as sess: - logger.info(f'loading model from {model_path}') - # load model - model_loader = tf.train.Saver( - variable_averages.variables_to_restore()) - model_loader.restore(sess, model_path) + self._graph = tf.get_default_graph() + config = 
tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + self._session = tf.Session(config=config) + + with self._graph.as_default(): + with device_placement(self.framework, self.device_name): + self.input_images = tf.placeholder( + tf.float32, shape=[1, 1024, 1024, 3], name='input_images') + self.output = {} + + with tf.variable_scope('', reuse=tf.AUTO_REUSE): + global_step = tf.get_variable( + 'global_step', [], + initializer=tf.constant_initializer(0), + dtype=tf.int64, + trainable=False) + variable_averages = tf.train.ExponentialMovingAverage( + 0.997, global_step) + + # detector + detector = SegLinkDetector() + all_maps = detector.build_model( + self.input_images, is_training=False) + + # decode local predictions + all_nodes, all_links, all_reg = [], [], [] + for i, maps in enumerate(all_maps): + cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[ + 2] + reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE) + + cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2])) + + lnk_prob_pos = tf.nn.softmax( + tf.reshape(lnk_maps, [-1, 4])[:, :2]) + lnk_prob_mut = tf.nn.softmax( + tf.reshape(lnk_maps, [-1, 4])[:, 2:]) + lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], + axis=1) + + all_nodes.append(cls_prob) + all_links.append(lnk_prob) + all_reg.append(reg_maps) + + # decode segments and links + image_size = tf.shape(self.input_images)[1:3] + segments, group_indices, segment_counts, _ = decode_segments_links_python( + image_size, + all_nodes, + all_links, + all_reg, + anchor_sizes=list(detector.anchor_sizes)) + + # combine segments + combined_rboxes, combined_counts = combine_segments_python( + segments, group_indices, segment_counts) + self.output['combined_rboxes'] = combined_rboxes + self.output['combined_counts'] = combined_counts + + with self._session.as_default() as sess: + logger.info(f'loading model from {model_path}') + # load model + model_loader = tf.train.Saver( + variable_averages.variables_to_restore()) + model_loader.restore(sess, model_path) def preprocess(self, input: Input) -> Dict[str, Any]: img = LoadImage.convert_to_ndarray(input) @@ -132,19 +137,22 @@ class OCRDetectionPipeline(Pipeline): img_pad_resize = img_pad_resize - np.array([123.68, 116.78, 103.94], dtype=np.float32) - resize_size = tf.stack([resize_size, resize_size]) - orig_size = tf.stack([max(h, w), max(h, w)]) - self.output['orig_size'] = orig_size - self.output['resize_size'] = resize_size + with self._graph.as_default(): + resize_size = tf.stack([resize_size, resize_size]) + orig_size = tf.stack([max(h, w), max(h, w)]) + self.output['orig_size'] = orig_size + self.output['resize_size'] = resize_size result = {'img': np.expand_dims(img_pad_resize, axis=0)} return result def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - with self._session.as_default(): - feed_dict = {self.input_images: input['img']} - sess_outputs = self._session.run(self.output, feed_dict=feed_dict) - return sess_outputs + with self._graph.as_default(): + with self._session.as_default(): + feed_dict = {self.input_images: input['img']} + sess_outputs = self._session.run( + self.output, feed_dict=feed_dict) + return sess_outputs def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: rboxes = inputs['combined_rboxes'][0] diff --git a/modelscope/pipelines/cv/ocr_recognition_pipeline.py b/modelscope/pipelines/cv/ocr_recognition_pipeline.py index c20d020c..e81467a1 100644 --- a/modelscope/pipelines/cv/ocr_recognition_pipeline.py +++ b/modelscope/pipelines/cv/ocr_recognition_pipeline.py @@ -1,3 
+1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import math import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py b/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py index 2614983b..0164a998 100644 --- a/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py +++ b/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/cv/product_segmentation_pipeline.py b/modelscope/pipelines/cv/product_segmentation_pipeline.py new file mode 100644 index 00000000..244b01d7 --- /dev/null +++ b/modelscope/pipelines/cv/product_segmentation_pipeline.py @@ -0,0 +1,40 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.product_segmentation import seg_infer +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.product_segmentation, module_name=Pipelines.product_segmentation) +class F3NetForProductSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create product segmentation pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + super().__init__(model=model, **kwargs) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + mask = seg_infer.inference(self.model, self.device, + input['input_path']) + return {OutputKeys.MASKS: mask} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/skin_retouching_pipeline.py b/modelscope/pipelines/cv/skin_retouching_pipeline.py index f8c9de60..c6571bef 100644 --- a/modelscope/pipelines/cv/skin_retouching_pipeline.py +++ b/modelscope/pipelines/cv/skin_retouching_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os from typing import Any, Dict diff --git a/modelscope/pipelines/cv/video_inpainting_pipeline.py b/modelscope/pipelines/cv/video_inpainting_pipeline.py index 15444e05..85133474 100644 --- a/modelscope/pipelines/cv/video_inpainting_pipeline.py +++ b/modelscope/pipelines/cv/video_inpainting_pipeline.py @@ -1,3 +1,4 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. from typing import Any, Dict from modelscope.metainfo import Pipelines diff --git a/modelscope/pipelines/cv/video_summarization_pipeline.py b/modelscope/pipelines/cv/video_summarization_pipeline.py index 001780e1..25ea1e7c 100644 --- a/modelscope/pipelines/cv/video_summarization_pipeline.py +++ b/modelscope/pipelines/cv/video_summarization_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
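The OCR detection changes above pin every TensorFlow op and the session to a single graph so that other TF-based pipelines in the same process cannot interfere with it. A minimal sketch of that pattern with an explicit tf.Graph (the pipeline itself reuses the default graph), using the same TF1 compatibility shim seen elsewhere in this diff:

import tensorflow as tf

if tf.__version__ >= '2.0':
    tf = tf.compat.v1
    tf.disable_eager_execution()

# build the ops on a dedicated graph
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None], name='x')
    y = x * 2.0

# bind the session to that graph and always re-enter it before running
sess = tf.Session(graph=graph, config=tf.ConfigProto(allow_soft_placement=True))
with graph.as_default(), sess.as_default():
    print(sess.run(y, feed_dict={x: [1.0, 2.0]}))  # -> [2. 4.]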
import os.path as osp from typing import Any, Dict diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 5267b5b2..be854593 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -4,12 +4,13 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: + from .automatic_post_editing_pipeline import AutomaticPostEditingPipeline from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline - from .table_question_answering_pipeline import TableQuestionAnsweringPipeline from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline from .dialog_modeling_pipeline import DialogModelingPipeline from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline from .document_segmentation_pipeline import DocumentSegmentationPipeline + from .fasttext_sequence_classification_pipeline import FasttextSequenceClassificationPipeline from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline from .feature_extraction_pipeline import FeatureExtractionPipeline from .fill_mask_pipeline import FillMaskPipeline @@ -20,6 +21,8 @@ if TYPE_CHECKING: from .sentence_embedding_pipeline import SentenceEmbeddingPipeline from .sequence_classification_pipeline import SequenceClassificationPipeline from .summarization_pipeline import SummarizationPipeline + from .table_question_answering_pipeline import TableQuestionAnsweringPipeline + from .translation_quality_estimation_pipeline import TranslationQualityEstimationPipeline from .text_classification_pipeline import TextClassificationPipeline from .text_error_correction_pipeline import TextErrorCorrectionPipeline from .text_generation_pipeline import TextGenerationPipeline @@ -31,14 +34,15 @@ if TYPE_CHECKING: else: _import_structure = { + 'automatic_post_editing_pipeline': ['AutomaticPostEditingPipeline'], 'conversational_text_to_sql_pipeline': ['ConversationalTextToSqlPipeline'], - 'table_question_answering_pipeline': - ['TableQuestionAnsweringPipeline'], 'dialog_intent_prediction_pipeline': ['DialogIntentPredictionPipeline'], 'dialog_modeling_pipeline': ['DialogModelingPipeline'], 'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'], + 'domain_classification_pipeline': + ['FasttextSequenceClassificationPipeline'], 'document_segmentation_pipeline': ['DocumentSegmentationPipeline'], 'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'], 'feature_extraction_pipeline': ['FeatureExtractionPipeline'], @@ -51,12 +55,16 @@ else: 'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'], 'sequence_classification_pipeline': ['SequenceClassificationPipeline'], 'summarization_pipeline': ['SummarizationPipeline'], + 'table_question_answering_pipeline': + ['TableQuestionAnsweringPipeline'], 'text_classification_pipeline': ['TextClassificationPipeline'], 'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'], 'text_generation_pipeline': ['TextGenerationPipeline'], 'text2text_generation_pipeline': ['Text2TextGenerationPipeline'], 'token_classification_pipeline': ['TokenClassificationPipeline'], 'translation_pipeline': ['TranslationPipeline'], + 'translation_quality_estimation_pipeline': + ['TranslationQualityEstimationPipeline'], 'word_segmentation_pipeline': ['WordSegmentationPipeline'], 'zero_shot_classification_pipeline': ['ZeroShotClassificationPipeline'], diff --git a/modelscope/pipelines/nlp/automatic_post_editing_pipeline.py 
b/modelscope/pipelines/nlp/automatic_post_editing_pipeline.py new file mode 100644 index 00000000..83968586 --- /dev/null +++ b/modelscope/pipelines/nlp/automatic_post_editing_pipeline.py @@ -0,0 +1,158 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from html import unescape +from typing import Any, Dict + +import jieba +import numpy as np +import tensorflow as tf +from sacremoses import (MosesDetokenizer, MosesDetruecaser, + MosesPunctNormalizer, MosesTokenizer, MosesTruecaser) +from sentencepiece import SentencePieceProcessor +from tensorflow.contrib.seq2seq.python.ops import beam_search_ops + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config, ConfigFields +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + tf.disable_eager_execution() + +logger = get_logger() + +__all__ = ['AutomaticPostEditingPipeline'] + + +@PIPELINES.register_module( + Tasks.translation, module_name=Pipelines.automatic_post_editing) +class AutomaticPostEditingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """Build an automatic post editing pipeline with a model dir. + + @param model: Model path for saved pb file + """ + super().__init__(model=model, **kwargs) + export_dir = model + self.cfg = Config.from_file( + os.path.join(export_dir, ModelFile.CONFIGURATION)) + joint_vocab_file = os.path.join( + export_dir, self.cfg[ConfigFields.preprocessor]['vocab']) + self.vocab = dict([(w.strip(), i) for i, w in enumerate( + open(joint_vocab_file, 'r', encoding='utf8'))]) + self.vocab_reverse = dict([(i, w.strip()) for i, w in enumerate( + open(joint_vocab_file, 'r', encoding='utf8'))]) + self.unk_id = self.cfg[ConfigFields.preprocessor].get('unk_id', -1) + strip_unk = self.cfg.get(ConfigFields.postprocessor, + {}).get('strip_unk', True) + self.unk_token = '' if strip_unk else self.cfg.get( + ConfigFields.postprocessor, {}).get('unk_token', '') + if self.unk_id == -1: + self.unk_id = len(self.vocab) - 1 + tf.reset_default_graph() + tf_config = tf.ConfigProto(allow_soft_placement=True) + tf_config.gpu_options.allow_growth = True + self._session = tf.Session(config=tf_config) + tf.saved_model.loader.load( + self._session, [tf.python.saved_model.tag_constants.SERVING], + export_dir) + default_graph = tf.get_default_graph() + self.input_src_id_placeholder = default_graph.get_tensor_by_name( + 'Placeholder:0') + self.input_src_len_placeholder = default_graph.get_tensor_by_name( + 'Placeholder_1:0') + self.input_mt_id_placeholder = default_graph.get_tensor_by_name( + 'Placeholder_2:0') + self.input_mt_len_placeholder = default_graph.get_tensor_by_name( + 'Placeholder_3:0') + output_id_beam = default_graph.get_tensor_by_name( + 'enc2enc/decoder/transpose:0') + output_len_beam = default_graph.get_tensor_by_name( + 'enc2enc/decoder/Minimum:0') + output_id = tf.cast( + tf.map_fn(lambda x: x[0], output_id_beam), dtype=tf.int64) + output_len = tf.map_fn(lambda x: x[0], output_len_beam) + self.output = {'output_ids': output_id, 'output_lens': output_len} + init = tf.global_variables_initializer() + local_init = tf.local_variables_initializer() + self._session.run([init, local_init]) + tf.saved_model.loader.load( + self._session, [tf.python.saved_model.tag_constants.SERVING], + 
export_dir) + + # preprocess + self._src_lang = self.cfg[ConfigFields.preprocessor]['src_lang'] + self._tgt_lang = self.cfg[ConfigFields.preprocessor]['tgt_lang'] + tok_escape = self.cfg[ConfigFields.preprocessor].get( + 'tokenize_escape', False) + src_tokenizer = MosesTokenizer(lang=self._src_lang) + mt_tokenizer = MosesTokenizer(lang=self._tgt_lang) + truecase_model = os.path.join( + export_dir, self.cfg[ConfigFields.preprocessor]['truecaser']) + truecaser = MosesTruecaser(load_from=truecase_model) + sp_model = os.path.join( + export_dir, self.cfg[ConfigFields.preprocessor]['sentencepiece']) + sp = SentencePieceProcessor() + sp.load(sp_model) + + self.src_preprocess = lambda x: ' '.join( + sp.encode_as_pieces( + truecaser.truecase( + src_tokenizer.tokenize( + x, return_str=True, escape=tok_escape), + return_str=True))) + self.mt_preprocess = lambda x: ' '.join( + sp.encode_as_pieces( + truecaser.truecase( + mt_tokenizer.tokenize( + x, return_str=True, escape=tok_escape), + return_str=True))) + + # post process, de-bpe, de-truecase, detok + detruecaser = MosesDetruecaser() + detokenizer = MosesDetokenizer(lang=self._tgt_lang) + self.postprocess_fun = lambda x: detokenizer.detokenize( + detruecaser.detruecase( + x.replace(' ▁', '@@').replace(' ', '').replace('@@', ' '). + strip()[1:], + return_str=True).split()) + + def preprocess(self, input: str) -> Dict[str, Any]: + src, mt = input.split('\005', 1) + src_sp, mt_sp = self.src_preprocess(src), self.mt_preprocess(mt) + input_src_ids = np.array( + [[self.vocab.get(w, self.unk_id) for w in src_sp.strip().split()]]) + input_mt_ids = np.array( + [[self.vocab.get(w, self.unk_id) for w in mt_sp.strip().split()]]) + input_src_lens = [len(x) for x in input_src_ids] + input_mt_lens = [len(x) for x in input_mt_ids] + feed_dict = { + self.input_src_id_placeholder: input_src_ids, + self.input_mt_id_placeholder: input_mt_ids, + self.input_src_len_placeholder: input_src_lens, + self.input_mt_len_placeholder: input_mt_lens + } + return feed_dict + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + with self._session.as_default(): + sess_outputs = self._session.run(self.output, feed_dict=input) + return sess_outputs + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + output_ids, output_len = inputs['output_ids'][0], inputs[ + 'output_lens'][0] + output_ids = output_ids[:output_len - 1] # -1 for + output_tokens = ' '.join([ + self.vocab_reverse.get(wid, self.unk_token) for wid in output_ids + ]) + post_editing_output = self.postprocess_fun(output_tokens) + result = {OutputKeys.TRANSLATION: post_editing_output} + return result diff --git a/modelscope/pipelines/nlp/fasttext_sequence_classification_pipeline.py b/modelscope/pipelines/nlp/fasttext_sequence_classification_pipeline.py new file mode 100644 index 00000000..f10af88f --- /dev/null +++ b/modelscope/pipelines/nlp/fasttext_sequence_classification_pipeline.py @@ -0,0 +1,69 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
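A usage sketch for the automatic post-editing pipeline above: the source sentence and the raw machine translation are passed as one string joined by the '\005' separator that preprocess splits on, matching tests/pipelines/test_automatic_post_editing.py later in this diff; the example sentences are made up:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

src = 'The cat sat on the mat.'
mt = 'Die Katze sass auf der Matte.'  # raw MT output to be post-edited

ape = pipeline(
    Tasks.translation,
    model='damo/nlp_automatic_post_editing_for_translation_en2de')
print(ape(input=src + '\005' + mt))  # dict keyed by OutputKeys.TRANSLATION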
+ +import os +from typing import Any, Dict, Union + +import numpy as np +import sentencepiece +from fasttext import load_model +from fasttext.FastText import _FastText + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['FasttextSequenceClassificationPipeline'] + + +def sentencepiece_tokenize(sp_model, sent): + tokens = [] + for t in sp_model.EncodeAsPieces(sent): + s = t.strip() + if s: + tokens.append(s) + return ' '.join(tokens) + + +@PIPELINES.register_module( + Tasks.text_classification, module_name=Pipelines.domain_classification) +class FasttextSequenceClassificationPipeline(Pipeline): + + def __init__(self, model: Union[str, _FastText], **kwargs): + """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction + + Args: + model: a model directory including model.bin and spm.model + preprocessor (SequenceClassificationPreprocessor): a preprocessor instance + """ + super().__init__(model=model) + model_file = os.path.join(model, ModelFile.TORCH_MODEL_BIN_FILE) + spm_file = os.path.join(model, 'sentencepiece.model') + assert os.path.isdir(model) and os.path.exists(model_file) and os.path.exists(spm_file), \ + '`model` should be a directory contains `model.bin` and `sentencepiece.model`' + self.model = load_model(model_file) + self.spm = sentencepiece.SentencePieceProcessor() + self.spm.Load(spm_file) + + def preprocess(self, inputs: str) -> Dict[str, Any]: + text = inputs.strip() + text_sp = sentencepiece_tokenize(self.spm, text) + return {'text_sp': text_sp, 'text': text} + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + topk = inputs.get('topk', -1) + label, probs = self.model.predict(inputs['text_sp'], k=topk) + label = [x.replace('__label__', '') for x in label] + result = { + OutputKeys.LABEL: label[0], + OutputKeys.SCORE: probs[0], + OutputKeys.LABELS: label, + OutputKeys.SCORES: probs + } + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/nlp/feature_extraction_pipeline.py b/modelscope/pipelines/nlp/feature_extraction_pipeline.py index 3af0c28d..e94e4337 100644 --- a/modelscope/pipelines/nlp/feature_extraction_pipeline.py +++ b/modelscope/pipelines/nlp/feature_extraction_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os from typing import Any, Dict, Optional, Union diff --git a/modelscope/pipelines/nlp/sequence_classification_pipeline.py b/modelscope/pipelines/nlp/sequence_classification_pipeline.py index 8d0e1dcd..69f6217a 100644 --- a/modelscope/pipelines/nlp/sequence_classification_pipeline.py +++ b/modelscope/pipelines/nlp/sequence_classification_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict, Union import numpy as np diff --git a/modelscope/pipelines/nlp/text2text_generation_pipeline.py b/modelscope/pipelines/nlp/text2text_generation_pipeline.py index 9ccd00f4..21aacf54 100644 --- a/modelscope/pipelines/nlp/text2text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text2text_generation_pipeline.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Any, Dict, Optional, Union import torch diff --git a/modelscope/pipelines/nlp/translation_quality_estimation_pipeline.py b/modelscope/pipelines/nlp/translation_quality_estimation_pipeline.py new file mode 100644 index 00000000..6ef203b9 --- /dev/null +++ b/modelscope/pipelines/nlp/translation_quality_estimation_pipeline.py @@ -0,0 +1,72 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import io +import os +from typing import Any, Dict, Union + +import numpy as np +import torch +from transformers import XLMRobertaTokenizer + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.nlp import BertForSequenceClassification +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['TranslationQualityEstimationPipeline'] + + +@PIPELINES.register_module( + Tasks.sentence_similarity, + module_name=Pipelines.translation_quality_estimation) +class TranslationQualityEstimationPipeline(Pipeline): + + def __init__(self, model: str, device: str = 'gpu', **kwargs): + super().__init__(model=model, device=device) + model_file = os.path.join(model, ModelFile.TORCH_MODEL_FILE) + with open(model_file, 'rb') as f: + buffer = io.BytesIO(f.read()) + self.tokenizer = XLMRobertaTokenizer.from_pretrained(model) + self.model = torch.jit.load( + buffer, map_location=self.device).to(self.device) + + def preprocess(self, inputs: Dict[str, Any]): + src_text = inputs['source_text'].strip() + tgt_text = inputs['target_text'].strip() + encoded_inputs = self.tokenizer.batch_encode_plus( + [[src_text, tgt_text]], + return_tensors='pt', + padding=True, + truncation=True) + input_ids = encoded_inputs['input_ids'].to(self.device) + attention_mask = encoded_inputs['attention_mask'].to(self.device) + inputs.update({ + 'input_ids': input_ids, + 'attention_mask': attention_mask + }) + return inputs + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + if 'input_ids' not in inputs: + inputs = self.preprocess(inputs) + res = self.model(inputs['input_ids'], inputs['attention_mask']) + result = { + OutputKeys.LABELS: '-1', + OutputKeys.SCORES: res[0].detach().squeeze().tolist() + } + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): input data dict + + Returns: + Dict[str, str]: the prediction results + """ + return inputs diff --git a/modelscope/preprocessors/ofa/utils/collate.py b/modelscope/preprocessors/ofa/utils/collate.py index b128c3fb..f7775680 100644 --- a/modelscope/preprocessors/ofa/utils/collate.py +++ b/modelscope/preprocessors/ofa/utils/collate.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import numpy as np import torch diff --git a/modelscope/preprocessors/ofa/utils/random_help.py b/modelscope/preprocessors/ofa/utils/random_help.py index 77f4df3f..e0dca54e 100644 --- a/modelscope/preprocessors/ofa/utils/random_help.py +++ b/modelscope/preprocessors/ofa/utils/random_help.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
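A usage sketch for the translation quality estimation pipeline above; the input keys source_text and target_text come from its preprocess method, while the model id below is only a placeholder for a compatible TorchScript QE model on the hub:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

qe = pipeline(
    Tasks.sentence_similarity,
    model='damo/nlp_translation_quality_estimation_placeholder')  # hypothetical model id
res = qe({'source_text': 'How are you?', 'target_text': 'Wie geht es dir?'})
print(res[OutputKeys.SCORES])  # quality score(s) from the jitted regression head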
+ import torch try: diff --git a/modelscope/trainers/multi_modal/clip/__init__.py b/modelscope/trainers/multi_modal/clip/__init__.py index 87f1040c..61a6664b 100644 --- a/modelscope/trainers/multi_modal/clip/__init__.py +++ b/modelscope/trainers/multi_modal/clip/__init__.py @@ -1 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + from .clip_trainer import CLIPTrainer diff --git a/modelscope/trainers/multi_modal/clip/clip_trainer.py b/modelscope/trainers/multi_modal/clip/clip_trainer.py index cccf4296..cbe83417 100644 --- a/modelscope/trainers/multi_modal/clip/clip_trainer.py +++ b/modelscope/trainers/multi_modal/clip/clip_trainer.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Dict, Optional diff --git a/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py b/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py index 1391a4fd..4e150fe7 100644 --- a/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py +++ b/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os import random diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index b19c0fce..5bc27c03 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -43,6 +43,9 @@ class CVTasks(object): text_driven_segmentation = 'text-driven-segmentation' shop_segmentation = 'shop-segmentation' hand_static = 'hand-static' + face_human_hand_detection = 'face-human-hand-detection' + face_emotion = 'face-emotion' + product_segmentation = 'product-segmentation' # image editing skin_retouching = 'skin-retouching' @@ -264,6 +267,7 @@ class ConfigFields(object): preprocessor = 'preprocessor' train = 'train' evaluation = 'evaluation' + postprocessor = 'postprocessor' class ConfigKeys(object): diff --git a/requirements/nlp.txt b/requirements/nlp.txt index 15f2f41a..f18dde2e 100644 --- a/requirements/nlp.txt +++ b/requirements/nlp.txt @@ -1,7 +1,10 @@ en_core_web_sm>=2.3.5 +fasttext jieba>=0.42.1 megatron_util pai-easynlp +# “protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged.” +protobuf>=3.19.0,<3.21.0 # rough-score was just recently updated from 0.0.4 to 0.0.7 # which introduced compatability issues that are being investigated rouge_score<=0.0.4 diff --git a/tests/pipelines/test_action_recognition.py b/tests/pipelines/test_action_recognition.py index b9548630..292eb238 100644 --- a/tests/pipelines/test_action_recognition.py +++ b/tests/pipelines/test_action_recognition.py @@ -29,6 +29,14 @@ class ActionRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): print(f'recognition output: {result}.') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_pst(self): + pst_recognition_pipeline = pipeline( + self.task, model='damo/cv_pathshift_action-recognition') + result = pst_recognition_pipeline( + 'data/test/videos/action_recognition_test_video.mp4') + print('pst recognition results:', result) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_demo_compatibility(self): self.compatibility_check() diff --git a/tests/pipelines/test_automatic_post_editing.py b/tests/pipelines/test_automatic_post_editing.py new file mode 100644 index 00000000..da09851c --- /dev/null +++ b/tests/pipelines/test_automatic_post_editing.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class AutomaticPostEditingTest(unittest.TestCase, DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        self.task = Tasks.translation
+        self.model_id = 'damo/nlp_automatic_post_editing_for_translation_en2de'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_for_en2de(self):
+        inputs = 'Simultaneously, the Legion took part to the pacification of Algeria, plagued by various tribal ' \
+                 'rebellions and razzias.\005Gleichzeitig nahm die Legion an der Befriedung Algeriens teil, die von ' \
+                 'verschiedenen Stammesaufständen und Rasias heimgesucht wurde.'
+        pipeline_ins = pipeline(self.task, model=self.model_id)
+        print(pipeline_ins(input=inputs))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_demo_compatibility(self):
+        self.compatibility_check()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/pipelines/test_body_3d_keypoints.py b/tests/pipelines/test_body_3d_keypoints.py
index 9dce0d19..6f27f12d 100644
--- a/tests/pipelines/test_body_3d_keypoints.py
+++ b/tests/pipelines/test_body_3d_keypoints.py
@@ -20,7 +20,7 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
         self.task = Tasks.body_3d_keypoints
 
     def pipeline_inference(self, pipeline: Pipeline, pipeline_input):
-        output = pipeline(pipeline_input)
+        output = pipeline(pipeline_input, output_video='./result.mp4')
         poses = np.array(output[OutputKeys.POSES])
         print(f'result 3d points shape {poses.shape}')
 
@@ -28,7 +28,9 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
     def test_run_modelhub_with_video_file(self):
         body_3d_keypoints = pipeline(
             Tasks.body_3d_keypoints, model=self.model_id)
-        self.pipeline_inference(body_3d_keypoints, self.test_video)
+        pipeline_input = self.test_video
+        self.pipeline_inference(
+            body_3d_keypoints, pipeline_input=pipeline_input)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_modelhub_with_video_stream(self):
@@ -37,12 +39,9 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
         if not cap.isOpened():
             raise Exception('modelscope error: %s cannot be decoded by OpenCV.'
                             % (self.test_video))
-        self.pipeline_inference(body_3d_keypoints, cap)
-
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_run_modelhub_default_model(self):
-        body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
-        self.pipeline_inference(body_3d_keypoints, self.test_video)
+        pipeline_input = self.test_video
+        self.pipeline_inference(
+            body_3d_keypoints, pipeline_input=pipeline_input)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_demo_compatibility(self):
diff --git a/tests/pipelines/test_domain_classification.py b/tests/pipelines/test_domain_classification.py
new file mode 100644
index 00000000..8e5bfa7f
--- /dev/null
+++ b/tests/pipelines/test_domain_classification.py
@@ -0,0 +1,45 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class DomainClassificationTest(unittest.TestCase, DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        self.task = Tasks.text_classification
+        self.model_id = 'damo/nlp_domain_classification_chinese'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_for_zh_domain(self):
+        inputs = '通过这种方式产生的离子吸收大地水分之后,可以通过潮解作用,将活性电解离子有效释放到周围土壤中,使接地极成为一个离子发生装置,' \
+                 '从而改善周边土质使之达到接地要求。'
+        pipeline_ins = pipeline(self.task, model=self.model_id)
+        print(pipeline_ins(input=inputs))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_for_zh_style(self):
+        model_id = 'damo/nlp_style_classification_chinese'
+        inputs = '通过这种方式产生的离子吸收大地水分之后,可以通过潮解作用,将活性电解离子有效释放到周围土壤中,使接地极成为一个离子发生装置,' \
+                 '从而改善周边土质使之达到接地要求。'
+        pipeline_ins = pipeline(self.task, model=model_id)
+        print(pipeline_ins(input=inputs))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_for_en_style(self):
+        model_id = 'damo/nlp_style_classification_english'
+        inputs = 'High Power 11.1V 5200mAh Lipo Battery For RC Car Robot Airplanes ' \
+                 'Helicopter RC Drone Parts 3s Lithium battery 11.1v Battery'
+        pipeline_ins = pipeline(self.task, model=model_id)
+        print(pipeline_ins(input=inputs))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_demo_compatibility(self):
+        self.compatibility_check()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/pipelines/test_face_emotion.py b/tests/pipelines/test_face_emotion.py
new file mode 100644
index 00000000..907e15ee
--- /dev/null
+++ b/tests/pipelines/test_face_emotion.py
@@ -0,0 +1,32 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class FaceEmotionTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model = 'damo/cv_face-emotion'
+        self.img = {'img_path': 'data/test/images/face_emotion.jpg'}
+
+    def pipeline_inference(self, pipeline: Pipeline, input: str):
+        result = pipeline(input)
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        face_emotion = pipeline(Tasks.face_emotion, model=self.model)
+        self.pipeline_inference(face_emotion, self.img)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        face_emotion = pipeline(Tasks.face_emotion)
+        self.pipeline_inference(face_emotion, self.img)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/pipelines/test_face_human_hand_detection.py b/tests/pipelines/test_face_human_hand_detection.py
new file mode 100644
index 00000000..7aaa67e7
--- /dev/null
+++ b/tests/pipelines/test_face_human_hand_detection.py
@@ -0,0 +1,38 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+
+class FaceHumanHandTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_nanodet_face-human-hand-detection'
+        self.input = {
+            'input_path': 'data/test/images/face_human_hand_detection.jpg',
+        }
+
+    def pipeline_inference(self, pipeline: Pipeline, input: str):
+        result = pipeline(input)
+        logger.info(result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        face_human_hand_detection = pipeline(
+            Tasks.face_human_hand_detection, model=self.model_id)
+        self.pipeline_inference(face_human_hand_detection, self.input)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        face_human_hand_detection = pipeline(Tasks.face_human_hand_detection)
+        self.pipeline_inference(face_human_hand_detection, self.input)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/pipelines/test_mplug_tasks.py b/tests/pipelines/test_mplug_tasks.py
index a3ace62d..11c9798f 100644
--- a/tests/pipelines/test_mplug_tasks.py
+++ b/tests/pipelines/test_mplug_tasks.py
@@ -26,7 +26,7 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             model=model,
         )
         image = Image.open('data/test/images/image_mplug_vqa.jpg')
-        result = pipeline_caption({'image': image})
+        result = pipeline_caption(image)
         print(result[OutputKeys.CAPTION])
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -35,7 +35,7 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             Tasks.image_captioning,
             model='damo/mplug_image-captioning_coco_base_en')
         image = Image.open('data/test/images/image_mplug_vqa.jpg')
-        result = pipeline_caption({'image': image})
+        result = pipeline_caption(image)
         print(result[OutputKeys.CAPTION])
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py
index d89e5d48..104c2869 100644
--- a/tests/pipelines/test_ofa_tasks.py
+++ b/tests/pipelines/test_ofa_tasks.py
@@ -34,7 +34,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             model=model,
         )
         image = 'data/test/images/image_captioning.png'
-        result = img_captioning({'image': image})
+        result = img_captioning(image)
         print(result[OutputKeys.CAPTION])
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -42,8 +42,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
         img_captioning = pipeline(
             Tasks.image_captioning,
             model='damo/ofa_image-caption_coco_large_en')
-        result = img_captioning(
-            {'image': 'data/test/images/image_captioning.png'})
+        result = img_captioning('data/test/images/image_captioning.png')
         print(result[OutputKeys.CAPTION])
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -52,8 +51,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             'damo/ofa_image-classification_imagenet_large_en')
         ofa_pipe = pipeline(Tasks.image_classification, model=model)
         image = 'data/test/images/image_classification.png'
-        input = {'image': image}
-        result = ofa_pipe(input)
+        result = ofa_pipe(image)
         print(result)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -62,8 +60,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             Tasks.image_classification,
             model='damo/ofa_image-classification_imagenet_large_en')
         image = 'data/test/images/image_classification.png'
-        input = {'image': image}
-        result = ofa_pipe(input)
+        result = ofa_pipe(image)
         print(result)
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -102,8 +99,8 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
         ofa_pipe = pipeline(Tasks.text_classification, model=model)
         text = 'One of our number will carry out your instructions minutely.'
         text2 = 'A member of my team will execute your orders with immense precision.'
-        input = {'text': text, 'text2': text2}
-        result = ofa_pipe(input)
+        result = ofa_pipe((text, text2))
+        result = ofa_pipe({'text': text, 'text2': text2})
         print(result)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -113,8 +110,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             model='damo/ofa_text-classification_mnli_large_en')
         text = 'One of our number will carry out your instructions minutely.'
         text2 = 'A member of my team will execute your orders with immense precision.'
-        input = {'text': text, 'text2': text2}
-        result = ofa_pipe(input)
+        result = ofa_pipe((text, text2))
        print(result)
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
diff --git a/tests/pipelines/test_product_segmentation.py b/tests/pipelines/test_product_segmentation.py
new file mode 100644
index 00000000..8f41c13c
--- /dev/null
+++ b/tests/pipelines/test_product_segmentation.py
@@ -0,0 +1,43 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+import unittest
+
+import cv2
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+
+class ProductSegmentationTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_F3Net_product-segmentation'
+        self.input = {
+            'input_path': 'data/test/images/product_segmentation.jpg'
+        }
+
+    def pipeline_inference(self, pipeline: Pipeline, input: str):
+        result = pipeline(input)
+        cv2.imwrite('test_product_segmentation_mask.jpg',
+                    result[OutputKeys.MASKS])
+        logger.info('test done')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        product_segmentation = pipeline(
+            Tasks.product_segmentation, model=self.model_id)
+        self.pipeline_inference(product_segmentation, self.input)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        product_segmentation = pipeline(Tasks.product_segmentation)
+        self.pipeline_inference(product_segmentation, self.input)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/pipelines/test_translation_quality_estimation.py b/tests/pipelines/test_translation_quality_estimation.py
new file mode 100644
index 00000000..315fa72b
--- /dev/null
+++ b/tests/pipelines/test_translation_quality_estimation.py
@@ -0,0 +1,32 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class TranslationQualityEstimationTest(unittest.TestCase,
+                                       DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        self.task = Tasks.sentence_similarity
+        self.model_id = 'damo/nlp_translation_quality_estimation_multilingual'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_for_en2zh(self):
+        inputs = {
+            'source_text': 'Love is a losing game',
+            'target_text': '宝贝,人和人一场游戏'
+        }
+        pipeline_ins = pipeline(self.task, model=self.model_id)
+        print(pipeline_ins(input=inputs))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_demo_compatibility(self):
+        self.compatibility_check()
+
+
+if __name__ == '__main__':
+    unittest.main()