* feat/nlp: [to #42322933] Add cv-action-recognition-pipeline to maas lib
  [to #42463204] Support PIL.Image for image_captioning_pipeline
  [to #42670107] Restore pydataset test
  [to #42322933] Add create-if-not-exist and add (back) create-model example
  Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9130661
  [to #41474818] fix: fix errors in task name definition
  # Conflicts:
  #   modelscope/pipelines/builder.py
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:24dc4237b1197321ee8486bb983fa01fd47e2b4afdb3c2df24229e5f2bd20119
size 1475924
@@ -40,6 +40,7 @@ class Pipelines(object):
    image_matting = 'unet-image-matting'
    person_image_cartoon = 'unet-person-image-cartoon'
    ocr_detection = 'resnet18-ocr-detection'
    action_recognition = 'TAdaConv_action-recognition'

    # nlp tasks
    sentence_similarity = 'sentence-similarity'
| @@ -0,0 +1,91 @@ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from .tada_convnext import TadaConvNeXt | |||||
| class BaseVideoModel(nn.Module): | |||||
| """ | |||||
| Standard video model. | |||||
| The model is divided into the backbone and the head, where the backbone | |||||
| extracts features and the head performs classification. | |||||
| The backbones can be defined in model/base/backbone.py or anywhere else | |||||
| as long as the backbone is registered by the BACKBONE_REGISTRY. | |||||
| The heads can be defined in model/module_zoo/heads/ or anywhere else | |||||
| as long as the head is registered by the HEAD_REGISTRY. | |||||
| The registries automatically finds the registered modules and construct | |||||
| the base video model. | |||||
| """ | |||||
| def __init__(self, cfg): | |||||
| """ | |||||
| Args: | |||||
| cfg (Config): global config object. | |||||
| """ | |||||
| super(BaseVideoModel, self).__init__() | |||||
| # the backbone is created according to meta-architectures | |||||
| # defined in models/base/backbone.py | |||||
| self.backbone = TadaConvNeXt(cfg) | |||||
| # the head is created according to the heads | |||||
| # defined in models/module_zoo/heads | |||||
| self.head = BaseHead(cfg) | |||||
| def forward(self, x): | |||||
| x = self.backbone(x) | |||||
| x = self.head(x) | |||||
| return x | |||||
| class BaseHead(nn.Module): | |||||
| """ | |||||
| Constructs base head. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| cfg, | |||||
| ): | |||||
| """ | |||||
| Args: | |||||
| cfg (Config): global config object. | |||||
| """ | |||||
| super(BaseHead, self).__init__() | |||||
| self.cfg = cfg | |||||
| dim = cfg.VIDEO.BACKBONE.NUM_OUT_FEATURES | |||||
| num_classes = cfg.VIDEO.HEAD.NUM_CLASSES | |||||
| dropout_rate = cfg.VIDEO.HEAD.DROPOUT_RATE | |||||
| activation_func = cfg.VIDEO.HEAD.ACTIVATION | |||||
| self._construct_head(dim, num_classes, dropout_rate, activation_func) | |||||
| def _construct_head(self, dim, num_classes, dropout_rate, activation_func): | |||||
| self.global_avg_pool = nn.AdaptiveAvgPool3d(1) | |||||
| if dropout_rate > 0.0: | |||||
| self.dropout = nn.Dropout(dropout_rate) | |||||
| self.out = nn.Linear(dim, num_classes, bias=True) | |||||
| if activation_func == 'softmax': | |||||
| self.activation = nn.Softmax(dim=-1) | |||||
| elif activation_func == 'sigmoid': | |||||
| self.activation = nn.Sigmoid() | |||||
| else: | |||||
| raise NotImplementedError('{} is not supported as an activation' | |||||
| 'function.'.format(activation_func)) | |||||
| def forward(self, x): | |||||
| if len(x.shape) == 5: | |||||
| x = self.global_avg_pool(x) | |||||
| # (N, C, T, H, W) -> (N, T, H, W, C). | |||||
| x = x.permute((0, 2, 3, 4, 1)) | |||||
| if hasattr(self, 'dropout'): | |||||
| out = self.dropout(x) | |||||
| else: | |||||
| out = x | |||||
| out = self.out(out) | |||||
| out = self.activation(out) | |||||
| out = out.view(out.shape[0], -1) | |||||
| return out, x.view(x.shape[0], -1) | |||||
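For orientation, here is a minimal sketch of how BaseVideoModel is driven by the config object. The nested SimpleNamespace below only mimics the attribute access of the real Config, and the field values are illustrative rather than the ones shipped in the model's configuration.json:

import torch
from types import SimpleNamespace as NS

from modelscope.models.cv.action_recognition.models import BaseVideoModel

# Illustrative config values; the real ones come from the model's configuration.json.
cfg = NS(VIDEO=NS(
    BACKBONE=NS(
        NUM_INPUT_CHANNELS=3,
        NUM_FILTERS=[96, 192, 384, 768],
        DEPTH=[3, 3, 9, 3],
        DROP_PATH=0.1,
        LARGE_SCALE_INIT_VALUE=1e-6,
        NUM_OUT_FEATURES=768,
        STEM=NS(T_KERNEL_SIZE=3, T_STRIDE=2),
        BRANCH=NS(ROUTE_FUNC_TYPE='normal_lngelu', ROUTE_FUNC_R=4,
                  ROUTE_FUNC_K=[3, 3]),
    ),
    HEAD=NS(NUM_CLASSES=400, DROPOUT_RATE=0.5, ACTIVATION='softmax'),
))

model = BaseVideoModel(cfg).eval()
clip = torch.randn(1, 3, 8, 64, 64)  # (N, C, T, H, W)
with torch.no_grad():
    probs, feats = model(clip)
print(probs.shape, feats.shape)  # torch.Size([1, 400]) torch.Size([1, 768])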
@@ -0,0 +1,472 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _triple


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py.
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0], ) + (1, ) * (
        x.ndim - 1)  # work with tensors of any dimensionality, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py.
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
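# Illustrative note (not part of the original file): in eval mode DropPath is the
# identity, while in training each sample is either zeroed out entirely or rescaled
# by 1 / keep_prob, so the expected value of the output equals the input:
#
#     dp = DropPath(drop_prob=0.2)
#     x = torch.ones(8, 4)
#     dp.eval();  assert torch.equal(dp(x), x)
#     dp.train(); y = dp(x)   # each row is all zeros or all 1 / 0.8 = 1.25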
class TadaConvNeXt(nn.Module):
    r""" ConvNeXt
    A PyTorch impl of : `A ConvNet for the 2020s` -
    https://arxiv.org/pdf/2201.03545.pdf
    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(
        self, cfg
        # in_chans=3, num_classes=1000,
        # depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
        # layer_scale_init_value=1e-6, head_init_scale=1.,
    ):
        super().__init__()
        in_chans = cfg.VIDEO.BACKBONE.NUM_INPUT_CHANNELS
        dims = cfg.VIDEO.BACKBONE.NUM_FILTERS
        drop_path_rate = cfg.VIDEO.BACKBONE.DROP_PATH
        depths = cfg.VIDEO.BACKBONE.DEPTH
        layer_scale_init_value = cfg.VIDEO.BACKBONE.LARGE_SCALE_INIT_VALUE
        stem_t_kernel_size = cfg.VIDEO.BACKBONE.STEM.T_KERNEL_SIZE if hasattr(
            cfg.VIDEO.BACKBONE.STEM, 'T_KERNEL_SIZE') else 2
        t_stride = cfg.VIDEO.BACKBONE.STEM.T_STRIDE if hasattr(
            cfg.VIDEO.BACKBONE.STEM, 'T_STRIDE') else 2

        self.downsample_layers = nn.ModuleList(
        )  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv3d(
                in_chans,
                dims[0],
                kernel_size=(stem_t_kernel_size, 4, 4),
                stride=(t_stride, 4, 4),
                padding=((stem_t_kernel_size - 1) // 2, 0, 0)),
            LayerNorm(dims[0], eps=1e-6, data_format='channels_first'))
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format='channels_first'),
                nn.Conv3d(
                    dims[i],
                    dims[i + 1],
                    kernel_size=(1, 2, 2),
                    stride=(1, 2, 2)),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList(
        )  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(*[
                TAdaConvNeXtBlock(
                    cfg,
                    dim=dims[i],
                    drop_path=dp_rates[cur + j],
                    layer_scale_init_value=layer_scale_init_value)
                for j in range(depths[i])
            ])
            self.stages.append(stage)
            cur += depths[i]
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        # global average pooling over (T, H, W): (N, C, T, H, W) -> (N, C)
        return self.norm(x.mean([-3, -2, -1]))

    def forward(self, x):
        if isinstance(x, dict):
            x = x['video']
        x = self.forward_features(x)
        return x

    def get_num_layers(self):
        return 12, 0
class ConvNeXtBlock(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch
    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, cfg, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv3d(
            dim, dim, kernel_size=(1, 7, 7), padding=(0, 3, 3),
            groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim,
            4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 4, 1)  # (N, C, T, H, W) -> (N, T, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 4, 1, 2, 3)  # (N, T, H, W, C) -> (N, C, T, H, W)
        x = input + self.drop_path(x)
        return x
class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    This refers to the ordering of the dimensions in the inputs. channels_last corresponds to
    inputs with shape (batch_size, ..., channels), while channels_first corresponds to the 5D
    video inputs used here, with shape (batch_size, channels, time, height, width).
    """

    def __init__(self,
                 normalized_shape,
                 eps=1e-6,
                 data_format='channels_last'):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ['channels_last', 'channels_first']:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == 'channels_last':
            return F.layer_norm(x, self.normalized_shape, self.weight,
                                self.bias, self.eps)
        elif self.data_format == 'channels_first':
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None, None] * x + self.bias[:, None, None,
                                                                 None]
            return x
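# Illustrative check (not part of the original file): for the 5D video tensors used
# here, the channels_first branch is equivalent to F.layer_norm applied over the
# channel dimension after permuting to channels_last:
#
#     ln = LayerNorm(16, data_format='channels_first')
#     x = torch.randn(2, 16, 4, 7, 7)
#     ref = F.layer_norm(x.permute(0, 2, 3, 4, 1), (16, ), ln.weight, ln.bias,
#                        ln.eps).permute(0, 4, 1, 2, 3)
#     assert torch.allclose(ln(x), ref, atol=1e-5)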
class TAdaConvNeXtBlock(nn.Module):
    r""" ConvNeXt Block with a temporally adaptive depthwise convolution (TAdaConv2d),
    whose weights are calibrated per frame by a routing function.
    There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch
    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, cfg, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        layer_scale_init_value = float(layer_scale_init_value)
        self.dwconv = TAdaConv2d(
            dim,
            dim,
            kernel_size=(1, 7, 7),
            padding=(0, 3, 3),
            groups=dim,
            cal_dim='cout')
        route_func_type = cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_TYPE
        if route_func_type == 'normal':
            self.dwconv_rf = RouteFuncMLP(
                c_in=dim,
                ratio=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_R,
                kernels=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_K,
                with_bias_cal=self.dwconv.bias is not None)
        elif route_func_type == 'normal_lngelu':
            self.dwconv_rf = RouteFuncMLPLnGelu(
                c_in=dim,
                ratio=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_R,
                kernels=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_K,
                with_bias_cal=self.dwconv.bias is not None)
        else:
            raise ValueError(
                'Unknown route_func_type: {}'.format(route_func_type))
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim,
            4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x, self.dwconv_rf(x))
        x = x.permute(0, 2, 3, 4, 1)  # (N, C, T, H, W) -> (N, T, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 4, 1, 2, 3)  # (N, T, H, W, C) -> (N, C, T, H, W)
        x = input + self.drop_path(x)
        return x
class RouteFuncMLPLnGelu(nn.Module):
    """
    The routing function for generating the calibration weights.
    """

    def __init__(self,
                 c_in,
                 ratio,
                 kernels,
                 with_bias_cal=False,
                 bn_eps=1e-5,
                 bn_mmt=0.1):
        """
        Args:
            c_in (int): number of input channels.
            ratio (int): reduction ratio for the routing function.
            kernels (list): temporal kernel sizes of the stacked 1D convolutions.
            with_bias_cal (bool): whether to also generate a calibration weight
                for the convolution bias.
        """
        super(RouteFuncMLPLnGelu, self).__init__()
        self.c_in = c_in
        self.with_bias_cal = with_bias_cal
        self.avgpool = nn.AdaptiveAvgPool3d((None, 1, 1))
        self.globalpool = nn.AdaptiveAvgPool3d(1)
        self.g = nn.Conv3d(
            in_channels=c_in,
            out_channels=c_in,
            kernel_size=1,
            padding=0,
        )
        self.a = nn.Conv3d(
            in_channels=c_in,
            out_channels=int(c_in // ratio),
            kernel_size=[kernels[0], 1, 1],
            padding=[kernels[0] // 2, 0, 0],
        )
        # self.bn = nn.BatchNorm3d(int(c_in//ratio), eps=bn_eps, momentum=bn_mmt)
        self.ln = LayerNorm(
            int(c_in // ratio), eps=1e-6, data_format='channels_first')
        self.gelu = nn.GELU()
        # self.relu = nn.ReLU(inplace=True)
        self.b = nn.Conv3d(
            in_channels=int(c_in // ratio),
            out_channels=c_in,
            kernel_size=[kernels[1], 1, 1],
            padding=[kernels[1] // 2, 0, 0],
            bias=False)
        self.b.skip_init = True
        # zero-init so that the initial calibration output is 1
        self.b.weight.data.zero_()
        if with_bias_cal:
            self.b_bias = nn.Conv3d(
                in_channels=int(c_in // ratio),
                out_channels=c_in,
                kernel_size=[kernels[1], 1, 1],
                padding=[kernels[1] // 2, 0, 0],
                bias=False)
            self.b_bias.skip_init = True
            # zero-init so that the initial calibration output is 1
            self.b_bias.weight.data.zero_()

    def forward(self, x):
        g = self.globalpool(x)
        x = self.avgpool(x)
        x = self.a(x + self.g(g))
        # x = self.bn(x)
        # x = self.relu(x)
        x = self.ln(x)
        x = self.gelu(x)
        if self.with_bias_cal:
            return [self.b(x) + 1, self.b_bias(x) + 1]
        else:
            return self.b(x) + 1
class TAdaConv2d(nn.Module):
    """
    Performs temporally adaptive 2D convolution.
    Currently, only application on 5D tensors is supported, which makes TAdaConv2d
    essentially a 3D convolution with a temporal kernel size of 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 cal_dim='cin'):
        """
        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            kernel_size (list): kernel size of TAdaConv2d.
            stride (list): stride for the convolution in TAdaConv2d.
            padding (list): padding for the convolution in TAdaConv2d.
            dilation (list): dilation of the convolution in TAdaConv2d.
            groups (int): number of groups for TAdaConv2d.
            bias (bool): whether to use bias in TAdaConv2d.
            cal_dim (str): calibrated dimension in TAdaConv2d.
                Supported inputs: "cin", "cout".
        """
        super(TAdaConv2d, self).__init__()
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)

        assert kernel_size[0] == 1
        assert stride[0] == 1
        assert padding[0] == 0
        assert dilation[0] == 1
        assert cal_dim in ['cin', 'cout']

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.cal_dim = cal_dim

        # base weights (W_b)
        self.weight = nn.Parameter(
            torch.Tensor(1, 1, out_channels, in_channels // groups,
                         kernel_size[1], kernel_size[2]))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(1, 1, out_channels))
        else:
            self.register_parameter('bias', None)

        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, alpha):
        """
        Args:
            x (tensor): feature to perform convolution on.
            alpha (tensor): calibration weight for the base weights.
                W_t = alpha_t * W_b
        """
        if isinstance(alpha, list):
            w_alpha, b_alpha = alpha[0], alpha[1]
        else:
            w_alpha = alpha
            b_alpha = None
        _, _, c_out, c_in, kh, kw = self.weight.size()
        b, c_in, t, h, w = x.size()
        x = x.permute(0, 2, 1, 3, 4).reshape(1, -1, h, w)

        if self.cal_dim == 'cin':
            # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, 1, C, H(1), W(1)
            # corresponding to calibrating the input channel
            weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(2)
                      * self.weight).reshape(-1, c_in // self.groups, kh, kw)
        elif self.cal_dim == 'cout':
            # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C, 1, H(1), W(1)
            # corresponding to calibrating the output channel
            weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(3)
                      * self.weight).reshape(-1, c_in // self.groups, kh, kw)

        bias = None
        if self.bias is not None:
            if b_alpha is not None:
                # b_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C
                bias = (b_alpha.permute(0, 2, 1, 3, 4).squeeze()
                        * self.bias).reshape(-1)
            else:
                bias = self.bias.repeat(b, t, 1).reshape(-1)

        output = F.conv2d(
            x,
            weight=weight,
            bias=bias,
            stride=self.stride[1:],
            padding=self.padding[1:],
            dilation=self.dilation[1:],
            groups=self.groups * b * t)
        output = output.view(b, t, c_out, output.size(-2),
                             output.size(-1)).permute(0, 2, 1, 3, 4)
        return output

    def __repr__(self):
        return f'TAdaConv2d({self.in_channels}, {self.out_channels}, kernel_size={self.kernel_size}, ' +\
            f"stride={self.stride}, padding={self.padding}, bias={self.bias is not None}, cal_dim=\"{self.cal_dim}\")"
@@ -49,6 +49,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.ocr_detection: (Pipelines.ocr_detection,
                          'damo/cv_resnet18_ocr-detection-line-level_damo'),
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask_large'),
    Tasks.action_recognition: (Pipelines.action_recognition,
                               'damo/cv_TAdaConv_action-recognition'),
}
@@ -1,3 +1,4 @@
from .action_recognition_pipeline import ActionRecognitionPipeline
from .image_cartoon_pipeline import ImageCartoonPipeline
from .image_matting_pipeline import ImageMattingPipeline
from .ocr_detection_pipeline import OCRDetectionPipeline
@@ -0,0 +1,65 @@
import math
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.action_recognition.models import BaseVideoModel
from modelscope.pipelines.base import Input
from modelscope.preprocessors.video import ReadVideoData
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
    Tasks.action_recognition, module_name=Pipelines.action_recognition)
class ActionRecognitionPipeline(Pipeline):

    def __init__(self, model: str):
        super().__init__(model=model)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')
        self.cfg = Config.from_file(config_path)
        self.infer_model = BaseVideoModel(cfg=self.cfg).cuda()
        self.infer_model.eval()
        self.infer_model.load_state_dict(
            torch.load(model_path)['model_state'])
        self.label_mapping = self.cfg.label_mapping
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        if isinstance(input, str):
            video_input_data = ReadVideoData(self.cfg, input).cuda()
        else:
            raise TypeError(f'input should be a str,'
                            f' but got {type(input)}')
        result = {'video_data': video_input_data}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        pred = self.perform_inference(input['video_data'])
        output_label = self.label_mapping[str(pred)]
        return {'output_label': output_label}

    @torch.no_grad()
    def perform_inference(self, data, max_bsz=4):
        # run the sampled clips through the model in chunks of max_bsz,
        # then average the predictions over all clips before taking the argmax
        iter_num = math.ceil(data.size(0) / max_bsz)
        preds_list = []
        for i in range(iter_num):
            preds_list.append(
                self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])[0])
        pred = torch.cat(preds_list, dim=0)
        return pred.mean(dim=0).argmax().item()

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
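For context, the registered pipeline is expected to be invoked roughly as follows (the video path is illustrative; the model id matches the default registered in builder.py, and a CUDA device is required since the model is moved to GPU above):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

recognizer = pipeline(Tasks.action_recognition,
                      model='damo/cv_TAdaConv_action-recognition')
result = recognizer('path/to/some_video.mp4')  # illustrative local path
print(result)  # e.g. {'output_label': 'abseiling'}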
@@ -45,6 +45,12 @@ TASK_OUTPUTS = {
    Tasks.image_matting: ['output_png'],
    Tasks.image_generation: ['output_png'],

    # action recognition result for single video
    # {
    #   "output_label": "abseiling"
    # }
    Tasks.action_recognition: ['output_label'],

    # pose estimation result for single sample
    # {
    #   "poses": np.array with shape [num_pose, num_keypoint, 3],
@@ -5,7 +5,7 @@ from .base import Preprocessor
from .builder import PREPROCESSORS, build_preprocessor
from .common import Compose
from .image import LoadImage, load_image
from .multi_model import OfaImageCaptionPreprocessor
from .multi_modal import OfaImageCaptionPreprocessor
from .nlp import *  # noqa F403
from .space.dialog_intent_prediction_preprocessor import *  # noqa F403
from .space.dialog_modeling_preprocessor import *  # noqa F403
@@ -73,7 +73,7 @@ class OfaImageCaptionPreprocessor(Preprocessor):
        self.eos_item = torch.LongTensor([task.src_dict.eos()])
        self.pad_idx = task.src_dict.pad()

    @type_assert(object, (str, tuple))
    @type_assert(object, (str, tuple, Image.Image))
    def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:

        def encode_text(text, length=None, append_bos=False, append_eos=False):

@@ -89,8 +89,8 @@ class OfaImageCaptionPreprocessor(Preprocessor):
            s = torch.cat([s, self.eos_item])
            return s

        if isinstance(input, Image.Image):
            patch_image = self.patch_resize_transform(input).unsqueeze(0)
        if isinstance(data, Image.Image):
            patch_image = self.patch_resize_transform(data).unsqueeze(0)
        else:
            patch_image = self.patch_resize_transform(
                load_image(data)).unsqueeze(0)
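With this change the OFA captioning preprocessor accepts an in-memory PIL image in addition to a path or URL string. A rough sketch of the difference, assuming `preprocessor` is an already constructed OfaImageCaptionPreprocessor and 'demo.jpg' is a hypothetical local file:

from PIL import Image

inputs = preprocessor(Image.open('demo.jpg'))  # PIL.Image now accepted directly
inputs = preprocessor('demo.jpg')              # strings are still routed through load_image()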
@@ -0,0 +1,232 @@
import math
import os
import random

import decord
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
import torch.utils.dlpack as dlpack
import torchvision.transforms._transforms_video as transforms
from decord import VideoReader
from torchvision.transforms import Compose


def ReadVideoData(cfg, video_path):
    """ Simple interface to load and preprocess video frames from a file.
    Args:
        cfg (Config): The global config object.
        video_path (str): video file path
    """
    data = _decode_video(cfg, video_path)
    transform = kinetics400_tranform(cfg)
    data_list = []
    for i in range(data.size(0)):
        for j in range(cfg.TEST.NUM_SPATIAL_CROPS):
            transform.transforms[1].set_spatial_index(j)
            data_list.append(transform(data[i]))
    return torch.stack(data_list, dim=0)


def kinetics400_tranform(cfg):
    """
    Configures the transform for the Kinetics-400 dataset:
    controlled spatial cropping followed by normalization.
    Args:
        cfg (Config): The global config object.
    """
    resize_video = KineticsResizedCrop(
        short_side_range=[cfg.DATA.TEST_SCALE, cfg.DATA.TEST_SCALE],
        crop_size=cfg.DATA.TEST_CROP_SIZE,
        num_spatial_crops=cfg.TEST.NUM_SPATIAL_CROPS)
    std_transform_list = [
        transforms.ToTensorVideo(), resize_video,
        transforms.NormalizeVideo(
            mean=cfg.DATA.MEAN, std=cfg.DATA.STD, inplace=True)
    ]
    return Compose(std_transform_list)
def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx,
                             num_clips, num_frames, interval, minus_interval):
    """
    Generates the frame index list using interval based sampling.
    Args:
        vid_length (int): the length of the whole video (valid selection range).
        vid_fps (int): the original video fps.
        target_fps (int): the normalized video fps.
        clip_idx (int): -1 for random temporal sampling, and non-negative values
            for sampling a specific clip from the video.
        num_clips (int): the total number of clips to be sampled from each video.
            Combined with clip_idx, the sampled clip is the "clip_idx-th"
            out of "num_clips" clips.
        num_frames (int): number of frames in each sampled clip.
        interval (int): the interval between sampled frames.
        minus_interval (bool): controls the end index.
    Returns:
        index (tensor): the sampled frame indexes
    """
    if num_frames == 1:
        index = [random.randint(0, vid_length - 1)]
    else:
        # transform FPS
        clip_length = num_frames * interval * vid_fps / target_fps

        max_idx = max(vid_length - clip_length, 0)
        start_idx = clip_idx * math.floor(max_idx / (num_clips - 1))
        if minus_interval:
            end_idx = start_idx + clip_length - interval
        else:
            end_idx = start_idx + clip_length - 1

        index = torch.linspace(start_idx, end_idx, num_frames)
        index = torch.clamp(index, 0, vid_length - 1).long()

    return index
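# Worked example (illustrative): with vid_length=300, vid_fps=30, target_fps=30,
# num_frames=8, interval=8, num_clips=4, clip_idx=0 and minus_interval=False:
#     clip_length = 8 * 8 * 30 / 30 = 64 frames
#     max_idx = 300 - 64 = 236, start_idx = 0, end_idx = 63
#     index = linspace(0, 63, 8) -> [0, 9, 18, 27, 36, 45, 54, 63]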
def _decode_video_frames_list(cfg, frames_list, vid_fps):
    """
    Samples clips from a video that is already decoded into a list of frames.
    Args:
        cfg (Config): The global config object.
        frames_list (list): all frames of a video; each frame is a numpy array.
        vid_fps (int): the fps of this video.
    Returns:
        frames (Tensor): video tensor data
    """
    assert isinstance(frames_list, list)
    num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
    frame_list = []
    for clip_idx in range(num_clips_per_video):
        # for each clip in the video, a list of frame indexes is generated
        # before gathering the specified frames from the list
        list_ = _interval_based_sampling(
            len(frames_list), vid_fps, cfg.DATA.TARGET_FPS, clip_idx,
            num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES,
            cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL)
        frames = None
        frames = torch.from_numpy(
            np.stack([frames_list[l_index] for l_index in list_.tolist()],
                     axis=0))
        frame_list.append(frames)
    frames = torch.stack(frame_list)
    if num_clips_per_video == 1:
        frames = frames.squeeze(0)
    return frames


def _decode_video(cfg, path):
    """
    Decodes the video from the given file path and samples clips from it.
    Args:
        cfg (Config): The global config object.
        path (str): video file path.
    Returns:
        frames (Tensor): video tensor data
    """
    vr = VideoReader(path)
    num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
    frame_list = []
    for clip_idx in range(num_clips_per_video):
        # for each clip in the video, a list of frame indexes is generated
        # before decoding the specified frames from the video
        list_ = _interval_based_sampling(
            len(vr), vr.get_avg_fps(), cfg.DATA.TARGET_FPS, clip_idx,
            num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES,
            cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL)
        frames = None
        if path.endswith('.avi'):
            # for .avi files, additionally decode a few leading frames and
            # discard them afterwards
            append_list = torch.arange(0, list_[0], 4)
            frames = dlpack.from_dlpack(
                vr.get_batch(torch.cat([append_list,
                                        list_])).to_dlpack()).clone()
            frames = frames[append_list.shape[0]:]
        else:
            frames = dlpack.from_dlpack(
                vr.get_batch(list_).to_dlpack()).clone()
        frame_list.append(frames)
    frames = torch.stack(frame_list)
    if num_clips_per_video == 1:
        frames = frames.squeeze(0)
    del vr
    return frames
class KineticsResizedCrop(object):
    """Perform resizing and cropping for the Kinetics-400 dataset.
    Args:
        short_side_range (list): The range for the length of the short side. For inference this should be [256, 256].
        crop_size (int): The cropped size for the frames.
        num_spatial_crops (int): The number of spatial regions cropped from each video.
    """

    def __init__(
        self,
        short_side_range,
        crop_size,
        num_spatial_crops=1,
    ):
        self.idx = -1
        self.short_side_range = short_side_range
        self.crop_size = int(crop_size)
        self.num_spatial_crops = num_spatial_crops

    def _get_controlled_crop(self, clip):
        """Perform a controlled crop on the video tensor.
        Args:
            clip (Tensor): the video data, with shape [..., H, W]
                (only the last two dimensions are used for cropping).
        """
        _, _, clip_height, clip_width = clip.shape

        length = self.short_side_range[0]

        if clip_height < clip_width:
            new_clip_height = int(length)
            new_clip_width = int(clip_width / clip_height * new_clip_height)
            new_clip = torch.nn.functional.interpolate(
                clip, size=(new_clip_height, new_clip_width), mode='bilinear')
        else:
            new_clip_width = int(length)
            new_clip_height = int(clip_height / clip_width * new_clip_width)
            new_clip = torch.nn.functional.interpolate(
                clip, size=(new_clip_height, new_clip_width), mode='bilinear')

        x_max = int(new_clip_width - self.crop_size)
        y_max = int(new_clip_height - self.crop_size)
        if self.num_spatial_crops == 1:
            x = x_max // 2
            y = y_max // 2
        elif self.num_spatial_crops == 3:
            if self.idx == 0:
                if new_clip_width == length:
                    x = x_max // 2
                    y = 0
                elif new_clip_height == length:
                    x = 0
                    y = y_max // 2
            elif self.idx == 1:
                x = x_max // 2
                y = y_max // 2
            elif self.idx == 2:
                if new_clip_width == length:
                    x = x_max // 2
                    y = y_max
                elif new_clip_height == length:
                    x = x_max
                    y = y_max // 2
        return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size]

    def set_spatial_index(self, idx):
        """Set the spatial cropping index for controlled cropping.
        Args:
            idx (int): the spatial index. The value should be in [0, 1, 2], selecting the
                left/top, center, or right/bottom crop along the long side, respectively.
        """
        self.idx = idx

    def __call__(self, clip):
        return self._get_controlled_crop(clip)
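A short sketch of the three-crop evaluation behaviour (tensor sizes are illustrative; the clip layout matches the (C, T, H, W) output of ToTensorVideo, and KineticsResizedCrop is assumed to be importable from this module):

import torch

crop = KineticsResizedCrop(short_side_range=[256, 256], crop_size=224,
                           num_spatial_crops=3)
clip = torch.rand(3, 16, 256, 340)  # a landscape clip whose short side is already 256
for idx in range(3):                # left, center and right crops
    crop.set_spatial_index(idx)
    out = crop(clip)
    print(idx, out.shape)           # each crop: torch.Size([3, 16, 224, 224])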
@@ -29,6 +29,7 @@ class Tasks(object):
    image_generation = 'image-generation'
    image_matting = 'image-matting'
    ocr_detection = 'ocr-detection'
    action_recognition = 'action-recognition'

    # nlp tasks
    zero_shot_classification = 'zero-shot-classification'
@@ -2,21 +2,39 @@
import os
import os.path as osp
from typing import List, Union
from typing import List, Optional, Union

from numpy import deprecate
from requests import HTTPError

from modelscope.hub.file_download import model_file_download
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.hub.utils.utils import get_cache_dir
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile


# temp solution before the hub-cache is in place
@deprecate
def get_model_cache_dir(model_id: str):
    return os.path.join(get_cache_dir(), model_id)


def create_model_if_not_exist(
        api,
        model_id: str,
        chinese_name: str,
        visibility: Optional[int] = 5,  # 1-private, 5-public
        license: Optional[str] = 'apache-2.0',
        revision: Optional[str] = 'master'):
    exists = True
    try:
        api.get_model(model_id=model_id, revision=revision)
    except HTTPError:
        exists = False
    if exists:
        print(f'model {model_id} already exists, skip creation.')
        return False
    else:
        api.create_model(
            model_id=model_id,
            chinese_name=chinese_name,
            visibility=visibility,
            license=license)
        print(f'model {model_id} successfully created.')
        return True


def read_config(model_id_or_path: str):
@@ -78,7 +78,7 @@ class Registry(object):
                    f'{self._name}[{default_group}] and will '
                    'be overwritten')
                logger.warning(f'{self._modules[default_group][module_name]}'
                               'to {module_cls}')
                               f'to {module_cls}')
            # also register module in the default group for faster access
            # only by module name
            self._modules[default_group][module_name] = module_cls
@@ -1,2 +1,3 @@
decord>=0.6.0
easydict
tf_slim
@@ -0,0 +1,33 @@
import unittest

from maas_hub.maas_api import MaasApi

from modelscope.utils.hub import create_model_if_not_exist

USER_NAME = 'maasadmin'
PASSWORD = '12345678'


class HubExampleTest(unittest.TestCase):

    def setUp(self):
        self.api = MaasApi()
        # note this is temporary before official account management is ready
        self.api.login(USER_NAME, PASSWORD)

    @unittest.skip('to be used for local test only')
    def test_example_model_creation(self):
        # ATTENTION: change to proper model names before use
        model_name = 'cv_unet_person-image-cartoon_compound-models'
        model_chinese_name = '达摩卡通化模型'
        model_org = 'damo'
        model_id = '%s/%s' % (model_org, model_name)
        created = create_model_if_not_exist(self.api, model_id,
                                            model_chinese_name)
        if not created:
            print('!! NOT created since model already exists !!')


if __name__ == '__main__':
    unittest.main()
@@ -1,6 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import subprocess
import tempfile
import unittest
@@ -8,7 +7,6 @@ import uuid

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.hub.utils.utils import get_gitlab_domain
@@ -0,0 +1,58 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# !/usr/bin/env python
import os.path as osp
import shutil
import tempfile
import unittest

import cv2

from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class ActionRecognitionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_TAdaConv_action-recognition'

    @unittest.skip('deprecated, download model from model hub instead')
    def test_run_with_direct_file_download(self):
        model_path = 'https://aquila2-online-models.oss-cn-shanghai.aliyuncs.com/maas_test/pytorch_model.pt'
        config_path = 'https://aquila2-online-models.oss-cn-shanghai.aliyuncs.com/maas_test/configuration.json'
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_file = osp.join(tmp_dir, ModelFile.TORCH_MODEL_FILE)
            with open(model_file, 'wb') as ofile1:
                ofile1.write(File.read(model_path))
            config_file = osp.join(tmp_dir, ModelFile.CONFIGURATION)
            with open(config_file, 'wb') as ofile2:
                ofile2.write(File.read(config_path))
            recognition_pipeline = pipeline(
                Tasks.action_recognition, model=tmp_dir)
            result = recognition_pipeline(
                'data/test/videos/action_recognition_test_video.mp4')
            print(f'recognition output: {result}.')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        recognition_pipeline = pipeline(
            Tasks.action_recognition, model=self.model_id)
        result = recognition_pipeline(
            'data/test/videos/action_recognition_test_video.mp4')
        print(f'recognition output: {result}.')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        recognition_pipeline = pipeline(Tasks.action_recognition)
        result = recognition_pipeline(
            'data/test/videos/action_recognition_test_video.mp4')
        print(f'recognition output: {result}.')


if __name__ == '__main__':
    unittest.main()
@@ -60,7 +60,7 @@ class ImageMattingTest(unittest.TestCase):
        cv2.imwrite('result.png', result['output_png'])
        print(f'Output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_modelscope_dataset(self):
        dataset = PyDataset.load('beans', split='train', target='image')
        img_matting = pipeline(Tasks.image_matting, model=self.model_id)
@@ -33,8 +33,6 @@ class ImgPreprocessor(Preprocessor):

class PyDatasetTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 2,
                         'skip test due to dataset api problem')
    def test_ds_basic(self):
        ms_ds_full = PyDataset.load('squad')
        ms_ds_full_hf = hfdata.load_dataset('squad')