Merge DAMO action recognition into MaaS-lib
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9134444
master
| @@ -39,6 +39,7 @@ class Pipelines(object): | |||
| image_matting = 'unet-image-matting' | |||
| person_image_cartoon = 'unet-person-image-cartoon' | |||
| ocr_detection = 'resnet18-ocr-detection' | |||
| action_recognition = 'TAdaConv_action-recognition' | |||
| # nlp tasks | |||
| sentence_similarity = 'sentence-similarity' | |||
| @@ -0,0 +1,91 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| from .tada_convnext import TadaConvNeXt | |||
| class BaseVideoModel(nn.Module): | |||
| """ | |||
| Standard video model. | |||
| The model is divided into the backbone and the head, where the backbone | |||
| extracts features and the head performs classification. | |||
| In this port, the backbone is fixed to the TadaConvNeXt defined in | |||
| tada_convnext.py, and the head is the BaseHead defined below, which applies | |||
| global average pooling, optional dropout and a linear classifier. | |||
| """ | |||
| def __init__(self, cfg): | |||
| """ | |||
| Args: | |||
| cfg (Config): global config object. | |||
| """ | |||
| super(BaseVideoModel, self).__init__() | |||
| # the backbone extracts spatio-temporal features from the input clip | |||
| self.backbone = TadaConvNeXt(cfg) | |||
| # the head maps the pooled backbone features to class predictions | |||
| self.head = BaseHead(cfg) | |||
| def forward(self, x): | |||
| x = self.backbone(x) | |||
| x = self.head(x) | |||
| return x | |||
| class BaseHead(nn.Module): | |||
| """ | |||
| Constructs base head. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| cfg, | |||
| ): | |||
| """ | |||
| Args: | |||
| cfg (Config): global config object. | |||
| """ | |||
| super(BaseHead, self).__init__() | |||
| self.cfg = cfg | |||
| dim = cfg.VIDEO.BACKBONE.NUM_OUT_FEATURES | |||
| num_classes = cfg.VIDEO.HEAD.NUM_CLASSES | |||
| dropout_rate = cfg.VIDEO.HEAD.DROPOUT_RATE | |||
| activation_func = cfg.VIDEO.HEAD.ACTIVATION | |||
| self._construct_head(dim, num_classes, dropout_rate, activation_func) | |||
| def _construct_head(self, dim, num_classes, dropout_rate, activation_func): | |||
| self.global_avg_pool = nn.AdaptiveAvgPool3d(1) | |||
| if dropout_rate > 0.0: | |||
| self.dropout = nn.Dropout(dropout_rate) | |||
| self.out = nn.Linear(dim, num_classes, bias=True) | |||
| if activation_func == 'softmax': | |||
| self.activation = nn.Softmax(dim=-1) | |||
| elif activation_func == 'sigmoid': | |||
| self.activation = nn.Sigmoid() | |||
| else: | |||
| raise NotImplementedError('{} is not supported as an activation ' | |||
| 'function.'.format(activation_func)) | |||
| def forward(self, x): | |||
| if len(x.shape) == 5: | |||
| x = self.global_avg_pool(x) | |||
| # (N, C, T, H, W) -> (N, T, H, W, C). | |||
| x = x.permute((0, 2, 3, 4, 1)) | |||
| if hasattr(self, 'dropout'): | |||
| out = self.dropout(x) | |||
| else: | |||
| out = x | |||
| out = self.out(out) | |||
| out = self.activation(out) | |||
| out = out.view(out.shape[0], -1) | |||
| return out, x.view(x.shape[0], -1) | |||
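A minimal sketch of how BaseVideoModel could be exercised on its own, assuming an EasyDict-based config that carries the same fields the code above reads (the real pipeline instead loads configuration.json through modelscope.utils.config.Config; the tiny dims/depths here are purely illustrative):

import torch
from easydict import EasyDict

from modelscope.models.cv.action_recognition.models import BaseVideoModel

# hand-written, illustrative config mirroring the fields read above
cfg = EasyDict({
    'VIDEO': {
        'BACKBONE': {
            'NUM_INPUT_CHANNELS': 3,
            'NUM_FILTERS': [32, 64, 128, 256],   # per-stage dims (illustrative)
            'DEPTH': [1, 1, 1, 1],               # per-stage block counts (illustrative)
            'DROP_PATH': 0.0,
            'LARGE_SCALE_INIT_VALUE': 1e-6,
            'NUM_OUT_FEATURES': 256,             # must match NUM_FILTERS[-1]
            'STEM': {},                          # T_KERNEL_SIZE / T_STRIDE fall back to 2
            'BRANCH': {
                'ROUTE_FUNC_TYPE': 'normal_lngelu',
                'ROUTE_FUNC_R': 4,
                'ROUTE_FUNC_K': [3, 3],
            },
        },
        'HEAD': {
            'NUM_CLASSES': 400,
            'DROPOUT_RATE': 0.5,
            'ACTIVATION': 'softmax',
        },
    },
})

model = BaseVideoModel(cfg).eval()
clip = torch.randn(1, 3, 8, 64, 64)              # (N, C, T, H, W)
with torch.no_grad():
    probs, feats = model(clip)
print(probs.shape, feats.shape)                  # (1, 400) class scores, (1, 256) pooled features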
| @@ -0,0 +1,472 @@ | |||
| import math | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from torch.nn.modules.utils import _pair, _triple | |||
| def drop_path(x, drop_prob: float = 0., training: bool = False): | |||
| """ | |||
| From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py. | |||
| Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
| This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, | |||
| the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... | |||
| See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for | |||
| changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use | |||
| 'survival rate' as the argument. | |||
| """ | |||
| if drop_prob == 0. or not training: | |||
| return x | |||
| keep_prob = 1 - drop_prob | |||
| shape = (x.shape[0], ) + (1, ) * ( | |||
| x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets | |||
| random_tensor = keep_prob + torch.rand( | |||
| shape, dtype=x.dtype, device=x.device) | |||
| random_tensor.floor_() # binarize | |||
| output = x.div(keep_prob) * random_tensor | |||
| return output | |||
| class DropPath(nn.Module): | |||
| """ | |||
| From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py. | |||
| Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
| """ | |||
| def __init__(self, drop_prob=None): | |||
| super(DropPath, self).__init__() | |||
| self.drop_prob = drop_prob | |||
| def forward(self, x): | |||
| return drop_path(x, self.drop_prob, self.training) | |||
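A quick illustration of the stochastic-depth behaviour described above, using the drop_path defined in this file: with drop_prob=0.5 each sample in the batch is either zeroed or rescaled by 1/keep_prob, so the output matches the input in expectation, and the function is the identity outside training.

import torch

torch.manual_seed(0)
x = torch.ones(8, 4, 2, 2, 2)                       # (N, C, T, H, W)
y = drop_path(x, drop_prob=0.5, training=True)
print(y.view(8, -1)[:, 0])                          # each sample is either 0.0 or 2.0
print(drop_path(x, drop_prob=0.5, training=False).equal(x))  # True: identity at eval time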
| class TadaConvNeXt(nn.Module): | |||
| r""" ConvNeXt | |||
| A PyTorch impl of : `A ConvNet for the 2020s` - | |||
| https://arxiv.org/pdf/2201.03545.pdf | |||
| Args: | |||
| in_chans (int): Number of input image channels. Default: 3 | |||
| num_classes (int): Number of classes for classification head. Default: 1000 | |||
| depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] | |||
| dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] | |||
| drop_path_rate (float): Stochastic depth rate. Default: 0. | |||
| layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. | |||
| head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. | |||
| """ | |||
| def __init__( | |||
| self, cfg | |||
| # in_chans=3, num_classes=1000, | |||
| # depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., | |||
| # layer_scale_init_value=1e-6, head_init_scale=1., | |||
| ): | |||
| super().__init__() | |||
| in_chans = cfg.VIDEO.BACKBONE.NUM_INPUT_CHANNELS | |||
| dims = cfg.VIDEO.BACKBONE.NUM_FILTERS | |||
| drop_path_rate = cfg.VIDEO.BACKBONE.DROP_PATH | |||
| depths = cfg.VIDEO.BACKBONE.DEPTH | |||
| layer_scale_init_value = cfg.VIDEO.BACKBONE.LARGE_SCALE_INIT_VALUE | |||
| stem_t_kernel_size = cfg.VIDEO.BACKBONE.STEM.T_KERNEL_SIZE if hasattr( | |||
| cfg.VIDEO.BACKBONE.STEM, 'T_KERNEL_SIZE') else 2 | |||
| t_stride = cfg.VIDEO.BACKBONE.STEM.T_STRIDE if hasattr( | |||
| cfg.VIDEO.BACKBONE.STEM, 'T_STRIDE') else 2 | |||
| self.downsample_layers = nn.ModuleList( | |||
| ) # stem and 3 intermediate downsampling conv layers | |||
| stem = nn.Sequential( | |||
| nn.Conv3d( | |||
| in_chans, | |||
| dims[0], | |||
| kernel_size=(stem_t_kernel_size, 4, 4), | |||
| stride=(t_stride, 4, 4), | |||
| padding=((stem_t_kernel_size - 1) // 2, 0, 0)), | |||
| LayerNorm(dims[0], eps=1e-6, data_format='channels_first')) | |||
| self.downsample_layers.append(stem) | |||
| for i in range(3): | |||
| downsample_layer = nn.Sequential( | |||
| LayerNorm(dims[i], eps=1e-6, data_format='channels_first'), | |||
| nn.Conv3d( | |||
| dims[i], | |||
| dims[i + 1], | |||
| kernel_size=(1, 2, 2), | |||
| stride=(1, 2, 2)), | |||
| ) | |||
| self.downsample_layers.append(downsample_layer) | |||
| self.stages = nn.ModuleList( | |||
| ) # 4 feature resolution stages, each consisting of multiple residual blocks | |||
| dp_rates = [ | |||
| x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) | |||
| ] | |||
| cur = 0 | |||
| for i in range(4): | |||
| stage = nn.Sequential(*[ | |||
| TAdaConvNeXtBlock( | |||
| cfg, | |||
| dim=dims[i], | |||
| drop_path=dp_rates[cur + j], | |||
| layer_scale_init_value=layer_scale_init_value) | |||
| for j in range(depths[i]) | |||
| ]) | |||
| self.stages.append(stage) | |||
| cur += depths[i] | |||
| self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer | |||
| def forward_features(self, x): | |||
| for i in range(4): | |||
| x = self.downsample_layers[i](x) | |||
| x = self.stages[i](x) | |||
| return self.norm(x.mean( | |||
| [-3, -2, -1])) # global average pooling, (N, C, T, H, W) -> (N, C) | |||
| def forward(self, x): | |||
| if isinstance(x, dict): | |||
| x = x['video'] | |||
| x = self.forward_features(x) | |||
| return x | |||
| def get_num_layers(self): | |||
| return 12, 0 | |||
| class ConvNeXtBlock(nn.Module): | |||
| r""" ConvNeXt Block. There are two equivalent implementations: | |||
| (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) | |||
| (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back | |||
| We use (2) as we find it slightly faster in PyTorch | |||
| Args: | |||
| dim (int): Number of input channels. | |||
| drop_path (float): Stochastic depth rate. Default: 0.0 | |||
| layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. | |||
| """ | |||
| def __init__(self, cfg, dim, drop_path=0., layer_scale_init_value=1e-6): | |||
| super().__init__() | |||
| self.dwconv = nn.Conv3d( | |||
| dim, dim, kernel_size=(1, 7, 7), padding=(0, 3, 3), | |||
| groups=dim) # depthwise conv | |||
| self.norm = LayerNorm(dim, eps=1e-6) | |||
| self.pwconv1 = nn.Linear( | |||
| dim, | |||
| 4 * dim) # pointwise/1x1 convs, implemented with linear layers | |||
| self.act = nn.GELU() | |||
| self.pwconv2 = nn.Linear(4 * dim, dim) | |||
| self.gamma = nn.Parameter( | |||
| layer_scale_init_value * torch.ones((dim)), | |||
| requires_grad=True) if layer_scale_init_value > 0 else None | |||
| self.drop_path = DropPath( | |||
| drop_path) if drop_path > 0. else nn.Identity() | |||
| def forward(self, x): | |||
| input = x | |||
| x = self.dwconv(x) | |||
| x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C) | |||
| x = self.norm(x) | |||
| x = self.pwconv1(x) | |||
| x = self.act(x) | |||
| x = self.pwconv2(x) | |||
| if self.gamma is not None: | |||
| x = self.gamma * x | |||
| x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W) | |||
| x = input + self.drop_path(x) | |||
| return x | |||
| class LayerNorm(nn.Module): | |||
| r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. | |||
| The ordering of the dimensions in the inputs. channels_last corresponds to inputs with | |||
| shape (batch_size, height, width, channels) while channels_first corresponds to inputs | |||
| with shape (batch_size, channels, height, width). | |||
| """ | |||
| def __init__(self, | |||
| normalized_shape, | |||
| eps=1e-6, | |||
| data_format='channels_last'): | |||
| super().__init__() | |||
| self.weight = nn.Parameter(torch.ones(normalized_shape)) | |||
| self.bias = nn.Parameter(torch.zeros(normalized_shape)) | |||
| self.eps = eps | |||
| self.data_format = data_format | |||
| if self.data_format not in ['channels_last', 'channels_first']: | |||
| raise NotImplementedError | |||
| self.normalized_shape = (normalized_shape, ) | |||
| def forward(self, x): | |||
| if self.data_format == 'channels_last': | |||
| return F.layer_norm(x, self.normalized_shape, self.weight, | |||
| self.bias, self.eps) | |||
| elif self.data_format == 'channels_first': | |||
| u = x.mean(1, keepdim=True) | |||
| s = (x - u).pow(2).mean(1, keepdim=True) | |||
| x = (x - u) / torch.sqrt(s + self.eps) | |||
| x = self.weight[:, None, None, None] * x + self.bias[:, None, None, | |||
| None] | |||
| return x | |||
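A small consistency check (illustrative) for the channels_first branch above: it should agree with torch's own LayerNorm applied after moving the channel dimension last, which is how the channels_last branch is used elsewhere in this file.

import torch
import torch.nn as nn

ln_cf = LayerNorm(16, eps=1e-6, data_format='channels_first')
ref = nn.LayerNorm(16, eps=1e-6)
x = torch.randn(2, 16, 4, 8, 8)                     # (N, C, T, H, W)
a = ln_cf(x)
b = ref(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
print(torch.allclose(a, b, atol=1e-5))              # True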
| class TAdaConvNeXtBlock(nn.Module): | |||
| r""" ConvNeXt Block. There are two equivalent implementations: | |||
| (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) | |||
| (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back | |||
| We use (2) as we find it slightly faster in PyTorch | |||
| Args: | |||
| dim (int): Number of input channels. | |||
| drop_path (float): Stochastic depth rate. Default: 0.0 | |||
| layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. | |||
| """ | |||
| def __init__(self, cfg, dim, drop_path=0., layer_scale_init_value=1e-6): | |||
| super().__init__() | |||
| layer_scale_init_value = float(layer_scale_init_value) | |||
| self.dwconv = TAdaConv2d( | |||
| dim, | |||
| dim, | |||
| kernel_size=(1, 7, 7), | |||
| padding=(0, 3, 3), | |||
| groups=dim, | |||
| cal_dim='cout') | |||
| route_func_type = cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_TYPE | |||
| if route_func_type == 'normal': | |||
| self.dwconv_rf = RouteFuncMLP( | |||
| c_in=dim, | |||
| ratio=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_R, | |||
| kernels=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_K, | |||
| with_bias_cal=self.dwconv.bias is not None) | |||
| elif route_func_type == 'normal_lngelu': | |||
| self.dwconv_rf = RouteFuncMLPLnGelu( | |||
| c_in=dim, | |||
| ratio=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_R, | |||
| kernels=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_K, | |||
| with_bias_cal=self.dwconv.bias is not None) | |||
| else: | |||
| raise ValueError( | |||
| 'Unknown route_func_type: {}'.format(route_func_type)) | |||
| self.norm = LayerNorm(dim, eps=1e-6) | |||
| self.pwconv1 = nn.Linear( | |||
| dim, | |||
| 4 * dim) # pointwise/1x1 convs, implemented with linear layers | |||
| self.act = nn.GELU() | |||
| self.pwconv2 = nn.Linear(4 * dim, dim) | |||
| self.gamma = nn.Parameter( | |||
| layer_scale_init_value * torch.ones((dim)), | |||
| requires_grad=True) if layer_scale_init_value > 0 else None | |||
| self.drop_path = DropPath( | |||
| drop_path) if drop_path > 0. else nn.Identity() | |||
| def forward(self, x): | |||
| input = x | |||
| x = self.dwconv(x, self.dwconv_rf(x)) | |||
| x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C) | |||
| x = self.norm(x) | |||
| x = self.pwconv1(x) | |||
| x = self.act(x) | |||
| x = self.pwconv2(x) | |||
| if self.gamma is not None: | |||
| x = self.gamma * x | |||
| x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W) | |||
| x = input + self.drop_path(x) | |||
| return x | |||
| class RouteFuncMLPLnGelu(nn.Module): | |||
| """ | |||
| The routing function for generating the calibration weights. | |||
| """ | |||
| def __init__(self, | |||
| c_in, | |||
| ratio, | |||
| kernels, | |||
| with_bias_cal=False, | |||
| bn_eps=1e-5, | |||
| bn_mmt=0.1): | |||
| """ | |||
| Args: | |||
| c_in (int): number of input channels. | |||
| ratio (int): reduction ratio for the routing function. | |||
| kernels (list): temporal kernel sizes of the stacked 1D convolutions | |||
| with_bias_cal (bool): whether to also generate calibration weights for the bias. | |||
| bn_eps (float): epsilon for BatchNorm (unused in this LayerNorm + GELU variant). | |||
| bn_mmt (float): momentum for BatchNorm (unused in this LayerNorm + GELU variant). | |||
| """ | |||
| super(RouteFuncMLPLnGelu, self).__init__() | |||
| self.c_in = c_in | |||
| self.with_bias_cal = with_bias_cal | |||
| self.avgpool = nn.AdaptiveAvgPool3d((None, 1, 1)) | |||
| self.globalpool = nn.AdaptiveAvgPool3d(1) | |||
| self.g = nn.Conv3d( | |||
| in_channels=c_in, | |||
| out_channels=c_in, | |||
| kernel_size=1, | |||
| padding=0, | |||
| ) | |||
| self.a = nn.Conv3d( | |||
| in_channels=c_in, | |||
| out_channels=int(c_in // ratio), | |||
| kernel_size=[kernels[0], 1, 1], | |||
| padding=[kernels[0] // 2, 0, 0], | |||
| ) | |||
| # self.bn = nn.BatchNorm3d(int(c_in//ratio), eps=bn_eps, momentum=bn_mmt) | |||
| self.ln = LayerNorm( | |||
| int(c_in // ratio), eps=1e-6, data_format='channels_first') | |||
| self.gelu = nn.GELU() | |||
| # self.relu = nn.ReLU(inplace=True) | |||
| self.b = nn.Conv3d( | |||
| in_channels=int(c_in // ratio), | |||
| out_channels=c_in, | |||
| kernel_size=[kernels[1], 1, 1], | |||
| padding=[kernels[1] // 2, 0, 0], | |||
| bias=False) | |||
| self.b.skip_init = True | |||
| self.b.weight.data.zero_() # to make sure the initial values | |||
| # of the output are 1. | |||
| if with_bias_cal: | |||
| self.b_bias = nn.Conv3d( | |||
| in_channels=int(c_in // ratio), | |||
| out_channels=c_in, | |||
| kernel_size=[kernels[1], 1, 1], | |||
| padding=[kernels[1] // 2, 0, 0], | |||
| bias=False) | |||
| self.b_bias.skip_init = True | |||
| self.b_bias.weight.data.zero_() # to make sure the initial values | |||
| # of the output are 1. | |||
| def forward(self, x): | |||
| g = self.globalpool(x) | |||
| x = self.avgpool(x) | |||
| x = self.a(x + self.g(g)) | |||
| # x = self.bn(x) | |||
| # x = self.relu(x) | |||
| x = self.ln(x) | |||
| x = self.gelu(x) | |||
| if self.with_bias_cal: | |||
| return [self.b(x) + 1, self.b_bias(x) + 1] | |||
| else: | |||
| return self.b(x) + 1 | |||
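Shape check for the routing function (illustrative values): it turns a (B, C, T, H, W) feature map into per-frame calibration weights of shape (B, C, T, 1, 1), and because self.b / self.b_bias are zero-initialised the weights are exactly 1 right after construction.

import torch

rf = RouteFuncMLPLnGelu(c_in=64, ratio=4, kernels=[3, 3], with_bias_cal=True)
w_alpha, b_alpha = rf(torch.randn(2, 64, 8, 7, 7))
print(w_alpha.shape, b_alpha.shape)                 # torch.Size([2, 64, 8, 1, 1]) for both
print(torch.allclose(w_alpha, torch.ones_like(w_alpha)))  # True at initialisation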
| class TAdaConv2d(nn.Module): | |||
| """ | |||
| Performs temporally adaptive 2D convolution. | |||
| Currently, only application on 5D tensors is supported, which makes TAdaConv2d | |||
| essentially a 3D convolution with temporal kernel size of 1. | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| padding=0, | |||
| dilation=1, | |||
| groups=1, | |||
| bias=True, | |||
| cal_dim='cin'): | |||
| """ | |||
| Args: | |||
| in_channels (int): number of input channels. | |||
| out_channels (int): number of output channels. | |||
| kernel_size (list): kernel size of TAdaConv2d. | |||
| stride (list): stride for the convolution in TAdaConv2d. | |||
| padding (list): padding for the convolution in TAdaConv2d. | |||
| dilation (list): dilation of the convolution in TAdaConv2d. | |||
| groups (int): number of groups for TAdaConv2d. | |||
| bias (bool): whether to use bias in TAdaConv2d. | |||
| cal_dim (str): the calibrated dimension in TAdaConv2d. | |||
| Supported values: "cin", "cout". | |||
| """ | |||
| kernel_size = _triple(kernel_size) | |||
| stride = _triple(stride) | |||
| padding = _triple(padding) | |||
| dilation = _triple(dilation) | |||
| assert kernel_size[0] == 1 | |||
| assert stride[0] == 1 | |||
| assert padding[0] == 0 | |||
| assert dilation[0] == 1 | |||
| assert cal_dim in ['cin', 'cout'] | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.kernel_size = kernel_size | |||
| self.stride = stride | |||
| self.padding = padding | |||
| self.dilation = dilation | |||
| self.groups = groups | |||
| self.cal_dim = cal_dim | |||
| # base weights (W_b) | |||
| self.weight = nn.Parameter( | |||
| torch.Tensor(1, 1, out_channels, in_channels // groups, | |||
| kernel_size[1], kernel_size[2])) | |||
| if bias: | |||
| self.bias = nn.Parameter(torch.Tensor(1, 1, out_channels)) | |||
| else: | |||
| self.register_parameter('bias', None) | |||
| nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) | |||
| if self.bias is not None: | |||
| fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) | |||
| bound = 1 / math.sqrt(fan_in) | |||
| nn.init.uniform_(self.bias, -bound, bound) | |||
| def forward(self, x, alpha): | |||
| """ | |||
| Args: | |||
| x (tensor): feature to perform convolution on. | |||
| alpha (tensor or list): calibration weight(s) for the base weights, | |||
| W_t = alpha_t * W_b; a list additionally carries a calibration for the bias. | |||
| """ | |||
| if isinstance(alpha, list): | |||
| w_alpha, b_alpha = alpha[0], alpha[1] | |||
| else: | |||
| w_alpha = alpha | |||
| b_alpha = None | |||
| _, _, c_out, c_in, kh, kw = self.weight.size() | |||
| b, c_in, t, h, w = x.size() | |||
| x = x.permute(0, 2, 1, 3, 4).reshape(1, -1, h, w) | |||
| if self.cal_dim == 'cin': | |||
| # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, 1, C, H(1), W(1) | |||
| # corresponding to calibrating the input channel | |||
| weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(2) | |||
| * self.weight).reshape(-1, c_in // self.groups, kh, kw) | |||
| elif self.cal_dim == 'cout': | |||
| # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C, 1, H(1), W(1) | |||
| # corresponding to calibrating the output channel | |||
| weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(3) | |||
| * self.weight).reshape(-1, c_in // self.groups, kh, kw) | |||
| bias = None | |||
| if self.bias is not None: | |||
| if b_alpha is not None: | |||
| # b_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C | |||
| bias = (b_alpha.permute(0, 2, 1, 3, 4).squeeze() | |||
| * self.bias).reshape(-1) | |||
| else: | |||
| bias = self.bias.repeat(b, t, 1).reshape(-1) | |||
| output = F.conv2d( | |||
| x, | |||
| weight=weight, | |||
| bias=bias, | |||
| stride=self.stride[1:], | |||
| padding=self.padding[1:], | |||
| dilation=self.dilation[1:], | |||
| groups=self.groups * b * t) | |||
| output = output.view(b, t, c_out, output.size(-2), | |||
| output.size(-1)).permute(0, 2, 1, 3, 4) | |||
| return output | |||
| def __repr__(self): | |||
| return f'TAdaConv2d({self.in_channels}, {self.out_channels}, kernel_size={self.kernel_size}, ' +\ | |||
| f"stride={self.stride}, padding={self.padding}, bias={self.bias is not None}, cal_dim=\"{self.cal_dim}\")" | |||
| @@ -37,6 +37,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/cv_unet_person-image-cartoon_compound-models'), | |||
| Tasks.ocr_detection: (Pipelines.ocr_detection, | |||
| 'damo/cv_resnet18_ocr-detection-line-level_damo'), | |||
| Tasks.action_recognition: (Pipelines.action_recognition, | |||
| 'damo/cv_TAdaConv_action-recognition'), | |||
| } | |||
| @@ -1,3 +1,4 @@ | |||
| from .action_recognition_pipeline import ActionRecognitionPipeline | |||
| from .image_cartoon_pipeline import ImageCartoonPipeline | |||
| from .image_matting_pipeline import ImageMattingPipeline | |||
| from .ocr_detection_pipeline import OCRDetectionPipeline | |||
| @@ -0,0 +1,65 @@ | |||
| import math | |||
| import os.path as osp | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.action_recognition.models import BaseVideoModel | |||
| from modelscope.pipelines.base import Input | |||
| from modelscope.preprocessors.video import ReadVideoData | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from ..base import Pipeline | |||
| from ..builder import PIPELINES | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.action_recognition, module_name=Pipelines.action_recognition) | |||
| class ActionRecognitionPipeline(Pipeline): | |||
| def __init__(self, model: str): | |||
| super().__init__(model=model) | |||
| model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) | |||
| logger.info(f'loading model from {model_path}') | |||
| config_path = osp.join(self.model, ModelFile.CONFIGURATION) | |||
| logger.info(f'loading config from {config_path}') | |||
| self.cfg = Config.from_file(config_path) | |||
| self.infer_model = BaseVideoModel(cfg=self.cfg).cuda() | |||
| self.infer_model.eval() | |||
| self.infer_model.load_state_dict(torch.load(model_path)['model_state']) | |||
| self.label_mapping = self.cfg.label_mapping | |||
| logger.info('load model done') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| if isinstance(input, str): | |||
| video_input_data = ReadVideoData(self.cfg, input).cuda() | |||
| else: | |||
| raise TypeError(f'input should be a str,' | |||
| f' but got {type(input)}') | |||
| result = {'video_data': video_input_data} | |||
| return result | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| pred = self.perform_inference(input['video_data']) | |||
| output_label = self.label_mapping[str(pred)] | |||
| return {'output_label': output_label} | |||
| @torch.no_grad() | |||
| def perform_inference(self, data, max_bsz=4): | |||
| iter_num = math.ceil(data.size(0) / max_bsz) | |||
| preds_list = [] | |||
| for i in range(iter_num): | |||
| preds_list.append( | |||
| self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])[0]) | |||
| pred = torch.cat(preds_list, dim=0) | |||
| return pred.mean(dim=0).argmax().item() | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
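Shape walk-through for perform_inference (illustrative; the actual clip counts come from the model's configuration.json): with, say, NUM_ENSEMBLE_VIEWS=4 and NUM_SPATIAL_CROPS=3, ReadVideoData yields 12 clips, they are run in ceil(12 / max_bsz) mini-batches, and the per-clip scores are averaged before argmax.

import math
import torch

num_clips = 4 * 3                                   # assumed ensemble views x spatial crops
scores = torch.rand(num_clips, 400)                 # stand-in for the per-clip class scores
video_pred = scores.mean(dim=0).argmax().item()     # same reduction as perform_inference
print(math.ceil(num_clips / 4), video_pred)         # 3 mini-batches at max_bsz=4, predicted id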
| @@ -45,6 +45,12 @@ TASK_OUTPUTS = { | |||
| Tasks.image_matting: ['output_png'], | |||
| Tasks.image_generation: ['output_png'], | |||
| # action recognition result for single video | |||
| # { | |||
| # "output_label": "abseiling" | |||
| # } | |||
| Tasks.action_recognition: ['output_label'], | |||
| # pose estimation result for single sample | |||
| # { | |||
| # "poses": np.array with shape [num_pose, num_keypoint, 3], | |||
| @@ -0,0 +1,232 @@ | |||
| import math | |||
| import os | |||
| import random | |||
| import decord | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.utils.data | |||
| import torch.utils.dlpack as dlpack | |||
| import torchvision.transforms._transforms_video as transforms | |||
| from decord import VideoReader | |||
| from torchvision.transforms import Compose | |||
| def ReadVideoData(cfg, video_path): | |||
| """ simple interface to load video frames from file | |||
| Args: | |||
| cfg (Config): The global config object. | |||
| video_path (str): video file path | |||
| Returns: | |||
| data (Tensor): stacked clips of shape [num_clips * num_spatial_crops, C, T, H, W] | |||
| """ | |||
| data = _decode_video(cfg, video_path) | |||
| transform = kinetics400_transform(cfg) | |||
| data_list = [] | |||
| for i in range(data.size(0)): | |||
| for j in range(cfg.TEST.NUM_SPATIAL_CROPS): | |||
| transform.transforms[1].set_spatial_index(j) | |||
| data_list.append(transform(data[i])) | |||
| return torch.stack(data_list, dim=0) | |||
| def kinetics400_transform(cfg): | |||
| """ | |||
| Configures the test-time transform for the Kinetics-400 dataset. | |||
| We apply controlled spatial cropping and normalization. | |||
| Args: | |||
| cfg (Config): The global config object. | |||
| """ | |||
| resize_video = KineticsResizedCrop( | |||
| short_side_range=[cfg.DATA.TEST_SCALE, cfg.DATA.TEST_SCALE], | |||
| crop_size=cfg.DATA.TEST_CROP_SIZE, | |||
| num_spatial_crops=cfg.TEST.NUM_SPATIAL_CROPS) | |||
| std_transform_list = [ | |||
| transforms.ToTensorVideo(), resize_video, | |||
| transforms.NormalizeVideo( | |||
| mean=cfg.DATA.MEAN, std=cfg.DATA.STD, inplace=True) | |||
| ] | |||
| return Compose(std_transform_list) | |||
| def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx, | |||
| num_clips, num_frames, interval, minus_interval): | |||
| """ | |||
| Generates the frame index list using interval based sampling. | |||
| Args: | |||
| vid_length (int): the length of the whole video (valid selection range). | |||
| vid_fps (int): the original video fps | |||
| target_fps (int): the normalized video fps | |||
| clip_idx (int): -1 for random temporal sampling, and positive values for | |||
| sampling specific clip from the video | |||
| num_clips (int): the total clips to be sampled from each video. | |||
| combined with clip_idx, the sampled video is the "clip_idx-th" | |||
| video from "num_clips" videos. | |||
| num_frames (int): number of frames in each sampled clips. | |||
| interval (int): the interval to sample each frame. | |||
| minus_interval (bool): if True, subtract one sampling interval (instead of one frame) when computing the clip end index | |||
| Returns: | |||
| index (tensor): the sampled frame indexes | |||
| """ | |||
| if num_frames == 1: | |||
| index = [random.randint(0, vid_length - 1)] | |||
| else: | |||
| # transform FPS | |||
| clip_length = num_frames * interval * vid_fps / target_fps | |||
| max_idx = max(vid_length - clip_length, 0) | |||
| start_idx = clip_idx * math.floor(max_idx / (num_clips - 1)) | |||
| if minus_interval: | |||
| end_idx = start_idx + clip_length - interval | |||
| else: | |||
| end_idx = start_idx + clip_length - 1 | |||
| index = torch.linspace(start_idx, end_idx, num_frames) | |||
| index = torch.clamp(index, 0, vid_length - 1).long() | |||
| return index | |||
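A worked example of the interval-based sampling above, using the function just defined (the numbers are illustrative, not the shipped config): for a 300-frame video at 30 fps with target_fps=30, 16 frames per clip and interval 4, clip_length = 16 * 4 * 30 / 30 = 64, max_idx = 300 - 64 = 236, clip k starts at k * floor(236 / 3), and the 16 indices are spread evenly over [start, start + 63] before clamping.

idx0 = _interval_based_sampling(vid_length=300, vid_fps=30, target_fps=30,
                                clip_idx=0, num_clips=4, num_frames=16,
                                interval=4, minus_interval=False)
print(idx0.tolist())                                # 16 indices, roughly every 4 frames within [0, 63]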
| def _decode_video_frames_list(cfg, frames_list, vid_fps): | |||
| """ | |||
| Decodes the video given the numpy frames. | |||
| Args: | |||
| cfg (Config): The global config object. | |||
| frames_list (list): all frames for a video, the frames should be numpy array. | |||
| vid_fps (int): the fps of this video. | |||
| Returns: | |||
| frames (Tensor): video tensor data | |||
| """ | |||
| assert isinstance(frames_list, list) | |||
| num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS | |||
| frame_list = [] | |||
| for clip_idx in range(num_clips_per_video): | |||
| # for each clip in the video, | |||
| # a list is generated before decoding the specified frames from the video | |||
| list_ = _interval_based_sampling( | |||
| len(frames_list), vid_fps, cfg.DATA.TARGET_FPS, clip_idx, | |||
| num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES, | |||
| cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL) | |||
| frames = None | |||
| frames = torch.from_numpy( | |||
| np.stack([frames_list[l_index] for l_index in list_.tolist()], | |||
| axis=0)) | |||
| frame_list.append(frames) | |||
| frames = torch.stack(frame_list) | |||
| if num_clips_per_video == 1: | |||
| frames = frames.squeeze(0) | |||
| return frames | |||
| def _decode_video(cfg, path): | |||
| """ | |||
| Decodes frames from the video file at the given path. | |||
| Args: | |||
| cfg (Config): The global config object. | |||
| path (str): video file path. | |||
| Returns: | |||
| frames (Tensor): video tensor data | |||
| """ | |||
| vr = VideoReader(path) | |||
| num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS | |||
| frame_list = [] | |||
| for clip_idx in range(num_clips_per_video): | |||
| # for each clip in the video, | |||
| # a list is generated before decoding the specified frames from the video | |||
| list_ = _interval_based_sampling( | |||
| len(vr), vr.get_avg_fps(), cfg.DATA.TARGET_FPS, clip_idx, | |||
| num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES, | |||
| cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL) | |||
| frames = None | |||
| if path.endswith('.avi'): | |||
| append_list = torch.arange(0, list_[0], 4) | |||
| frames = dlpack.from_dlpack( | |||
| vr.get_batch(torch.cat([append_list, | |||
| list_])).to_dlpack()).clone() | |||
| frames = frames[append_list.shape[0]:] | |||
| else: | |||
| frames = dlpack.from_dlpack( | |||
| vr.get_batch(list_).to_dlpack()).clone() | |||
| frame_list.append(frames) | |||
| frames = torch.stack(frame_list) | |||
| if num_clips_per_video == 1: | |||
| frames = frames.squeeze(0) | |||
| del vr | |||
| return frames | |||
| class KineticsResizedCrop(object): | |||
| """Perform resize and crop for kinetics-400 dataset | |||
| Args: | |||
| short_side_range (list): the range for the short-side length. For inference this should be [256, 256]. | |||
| crop_size (int): The cropped size for frames. | |||
| num_spatial_crops (int): The number of the cropped spatial regions in each video. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| short_side_range, | |||
| crop_size, | |||
| num_spatial_crops=1, | |||
| ): | |||
| self.idx = -1 | |||
| self.short_side_range = short_side_range | |||
| self.crop_size = int(crop_size) | |||
| self.num_spatial_crops = num_spatial_crops | |||
| def _get_controlled_crop(self, clip): | |||
| """Perform controlled crop for video tensor. | |||
| Args: | |||
| clip (Tensor): the video data, e.g. of shape [C, T, H, W]; only the last two (H, W) dimensions are used | |||
| """ | |||
| _, _, clip_height, clip_width = clip.shape | |||
| length = self.short_side_range[0] | |||
| if clip_height < clip_width: | |||
| new_clip_height = int(length) | |||
| new_clip_width = int(clip_width / clip_height * new_clip_height) | |||
| new_clip = torch.nn.functional.interpolate( | |||
| clip, size=(new_clip_height, new_clip_width), mode='bilinear') | |||
| else: | |||
| new_clip_width = int(length) | |||
| new_clip_height = int(clip_height / clip_width * new_clip_width) | |||
| new_clip = torch.nn.functional.interpolate( | |||
| clip, size=(new_clip_height, new_clip_width), mode='bilinear') | |||
| x_max = int(new_clip_width - self.crop_size) | |||
| y_max = int(new_clip_height - self.crop_size) | |||
| if self.num_spatial_crops == 1: | |||
| x = x_max // 2 | |||
| y = y_max // 2 | |||
| elif self.num_spatial_crops == 3: | |||
| if self.idx == 0: | |||
| if new_clip_width == length: | |||
| x = x_max // 2 | |||
| y = 0 | |||
| elif new_clip_height == length: | |||
| x = 0 | |||
| y = y_max // 2 | |||
| elif self.idx == 1: | |||
| x = x_max // 2 | |||
| y = y_max // 2 | |||
| elif self.idx == 2: | |||
| if new_clip_width == length: | |||
| x = x_max // 2 | |||
| y = y_max | |||
| elif new_clip_height == length: | |||
| x = x_max | |||
| y = y_max // 2 | |||
| return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size] | |||
| def set_spatial_index(self, idx): | |||
| """Set the spatial cropping index for controlled cropping.. | |||
| Args: | |||
| idx (int): the spatial index. The value should be in [0, 1, 2], means [left, center, right], respectively. | |||
| """ | |||
| self.idx = idx | |||
| def __call__(self, clip): | |||
| return self._get_controlled_crop(clip) | |||
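Usage sketch for KineticsResizedCrop (sizes illustrative): the three-crop test protocol resizes the short side to 256 and then takes the left/top, center and right/bottom 224x224 crops by cycling set_spatial_index over 0, 1, 2, as ReadVideoData does above.

import torch

crop = KineticsResizedCrop(short_side_range=[256, 256], crop_size=224,
                           num_spatial_crops=3)
clip = torch.rand(3, 8, 360, 640)                   # (C, T, H, W), as produced by ToTensorVideo
views = []
for j in range(3):
    crop.set_spatial_index(j)
    views.append(crop(clip))
print([tuple(v.shape) for v in views])              # three views of shape (3, 8, 224, 224)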
| @@ -29,6 +29,7 @@ class Tasks(object): | |||
| image_generation = 'image-generation' | |||
| image_matting = 'image-matting' | |||
| ocr_detection = 'ocr-detection' | |||
| action_recognition = 'action-recognition' | |||
| # nlp tasks | |||
| word_segmentation = 'word-segmentation' | |||
| @@ -1,2 +1,3 @@ | |||
| decord>=0.6.0 | |||
| easydict | |||
| tf_slim | |||
| @@ -0,0 +1,58 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| #!/usr/bin/env python | |||
| import os.path as osp | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| import cv2 | |||
| from modelscope.fileio import File | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pydatasets import PyDataset | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class ActionRecognitionTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/cv_TAdaConv_action-recognition' | |||
| @unittest.skip('deprecated, download model from model hub instead') | |||
| def test_run_with_direct_file_download(self): | |||
| model_path = 'https://aquila2-online-models.oss-cn-shanghai.aliyuncs.com/maas_test/pytorch_model.pt' | |||
| config_path = 'https://aquila2-online-models.oss-cn-shanghai.aliyuncs.com/maas_test/configuration.json' | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| model_file = osp.join(tmp_dir, ModelFile.TORCH_MODEL_FILE) | |||
| with open(model_file, 'wb') as ofile1: | |||
| ofile1.write(File.read(model_path)) | |||
| config_file = osp.join(tmp_dir, ModelFile.CONFIGURATION) | |||
| with open(config_file, 'wb') as ofile2: | |||
| ofile2.write(File.read(config_path)) | |||
| recognition_pipeline = pipeline( | |||
| Tasks.action_recognition, model=tmp_dir) | |||
| result = recognition_pipeline( | |||
| 'data/test/videos/action_recognition_test_video.mp4') | |||
| print(f'recognition output: {result}.') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| recognition_pipeline = pipeline( | |||
| Tasks.action_recognition, model=self.model_id) | |||
| result = recognition_pipeline( | |||
| 'data/test/videos/action_recognition_test_video.mp4') | |||
| print(f'recognition output: {result}.') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub_default_model(self): | |||
| recognition_pipeline = pipeline(Tasks.action_recognition) | |||
| result = recognition_pipeline( | |||
| 'data/test/videos/action_recognition_test_video.mp4') | |||
| print(f'recognition output: {result}.') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||