Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10773667 (master^2)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f4b7e23f02a35136707ac7862e0a8468797f239e89497351847cfacb2a9c24f6
size 202112
@@ -151,6 +151,7 @@ class Pipelines(object):
    image_denoise = 'nafnet-image-denoise'
    person_image_cartoon = 'unet-person-image-cartoon'
    ocr_detection = 'resnet18-ocr-detection'
    table_recognition = 'dla34-table-recognition'
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognition = 'resnet101-animal-recognition'
    general_recognition = 'resnet101-general-recognition'
@@ -59,6 +59,7 @@ TASK_OUTPUTS = {
    #   [x1, y1, x2, y2, x3, y3, x4, y4]
    # }
    Tasks.ocr_detection: [OutputKeys.POLYGONS],
    Tasks.table_recognition: [OutputKeys.POLYGONS],
    # ocr recognition result for single sample
    # {
@@ -82,6 +82,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                 'damo/cv_unet_person-image-cartoon_compound-models'),
    Tasks.ocr_detection: (Pipelines.ocr_detection,
                          'damo/cv_resnet18_ocr-detection-line-level_damo'),
    Tasks.table_recognition:
    (Pipelines.table_recognition,
     'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.feature_extraction: (Pipelines.feature_extraction,
                               'damo/pert_feature-extraction_base-test'),
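With the registry and default-model entries above, the new task works end to end. A minimal usage sketch (the model id comes from the mapping above; the image path is the one used by the test at the bottom of this review):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Resolves to 'damo/cv_dla34_table-structure-recognition_cycle-centernet'
    # through DEFAULT_MODEL_FOR_PIPELINE when no model id is passed.
    table_recognition = pipeline(Tasks.table_recognition)
    result = table_recognition('data/test/images/table_recognition.jpg')
    print(result)  # {'polygons': (N, 8) array, one row of corner coords per cell}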
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
    from .live_category_pipeline import LiveCategoryPipeline
    from .ocr_detection_pipeline import OCRDetectionPipeline
    from .ocr_recognition_pipeline import OCRRecognitionPipeline
    from .table_recognition_pipeline import TableRecognitionPipeline
    from .skin_retouching_pipeline import SkinRetouchingPipeline
    from .tinynas_classification_pipeline import TinynasClassificationPipeline
    from .video_category_pipeline import VideoCategoryPipeline
@@ -108,6 +109,7 @@ else:
        'image_inpainting_pipeline': ['ImageInpaintingPipeline'],
        'ocr_detection_pipeline': ['OCRDetectionPipeline'],
        'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
        'table_recognition_pipeline': ['TableRecognitionPipeline'],
        'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
        'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],
        'video_category_pipeline': ['VideoCategoryPipeline'],
@@ -0,0 +1,655 @@
# ------------------------------------------------------------------------------
# The implementation is adapted from CenterNet,
# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
# ------------------------------------------------------------------------------
import math

import numpy as np
import torch
from torch import nn

BatchNorm = nn.BatchNorm2d
class BasicBlock(nn.Module):

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation)
        self.bn1 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=dilation,
            bias=False,
            dilation=dilation)
        self.bn2 = BatchNorm(planes)
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(Bottleneck, self).__init__()
        expansion = Bottleneck.expansion
        bottle_planes = planes // expansion
        self.conv1 = nn.Conv2d(
            inplanes, bottle_planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(bottle_planes)
        self.conv2 = nn.Conv2d(
            bottle_planes,
            bottle_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation)
        self.bn2 = BatchNorm(bottle_planes)
        self.conv3 = nn.Conv2d(
            bottle_planes, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += residual
        out = self.relu(out)
        return out


class BottleneckX(nn.Module):
    expansion = 2
    cardinality = 32

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BottleneckX, self).__init__()
        cardinality = BottleneckX.cardinality
        bottle_planes = planes * cardinality // 32
        self.conv1 = nn.Conv2d(
            inplanes, bottle_planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(bottle_planes)
        self.conv2 = nn.Conv2d(
            bottle_planes,
            bottle_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation,
            groups=cardinality)
        self.bn2 = BatchNorm(bottle_planes)
        self.conv3 = nn.Conv2d(
            bottle_planes, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += residual
        out = self.relu(out)
        return out


class Root(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            1,
            stride=1,
            bias=False,
            padding=(kernel_size - 1) // 2)
        self.bn = BatchNorm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *x):
        children = x
        x = self.conv(torch.cat(x, 1))
        x = self.bn(x)
        if self.residual:
            x += children[0]
        x = self.relu(x)
        return x


class Tree(nn.Module):

    def __init__(self,
                 levels,
                 block,
                 in_channels,
                 out_channels,
                 stride=1,
                 level_root=False,
                 root_dim=0,
                 root_kernel_size=1,
                 dilation=1,
                 root_residual=False):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels
        if levels == 1:
            self.tree1 = block(
                in_channels, out_channels, stride, dilation=dilation)
            self.tree2 = block(
                out_channels, out_channels, 1, dilation=dilation)
        else:
            self.tree1 = Tree(
                levels - 1,
                block,
                in_channels,
                out_channels,
                stride,
                root_dim=0,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                root_residual=root_residual)
            self.tree2 = Tree(
                levels - 1,
                block,
                out_channels,
                out_channels,
                root_dim=root_dim + out_channels,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                root_residual=root_residual)
        if levels == 1:
            self.root = Root(root_dim, out_channels, root_kernel_size,
                             root_residual)
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.levels = levels
        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        if in_channels != out_channels:
            self.project = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    bias=False), BatchNorm(out_channels))

    def forward(self, x, residual=None, children=None):
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        residual = self.project(bottom) if self.project else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.levels == 1:
            x2 = self.tree2(x1)
            x = self.root(x2, x1, *children)
        else:
            children.append(x1)
            x = self.tree2(x1, children=children)
        return x


class DLA(nn.Module):

    def __init__(self,
                 levels,
                 channels,
                 num_classes=1000,
                 block=BasicBlock,
                 residual_root=False,
                 return_levels=False,
                 pool_size=7,
                 linear_root=False):
        super(DLA, self).__init__()
        self.channels = channels
        self.return_levels = return_levels
        self.num_classes = num_classes
        self.base_layer = nn.Sequential(
            nn.Conv2d(
                3, channels[0], kernel_size=7, stride=1, padding=3,
                bias=False), BatchNorm(channels[0]), nn.ReLU(inplace=True))
        self.level0 = self._make_conv_level(channels[0], channels[0],
                                            levels[0])
        self.level1 = self._make_conv_level(
            channels[0], channels[1], levels[1], stride=2)
        self.level2 = Tree(
            levels[2],
            block,
            channels[1],
            channels[2],
            2,
            level_root=False,
            root_residual=residual_root)
        self.level3 = Tree(
            levels[3],
            block,
            channels[2],
            channels[3],
            2,
            level_root=True,
            root_residual=residual_root)
        self.level4 = Tree(
            levels[4],
            block,
            channels[3],
            channels[4],
            2,
            level_root=True,
            root_residual=residual_root)
        self.level5 = Tree(
            levels[5],
            block,
            channels[4],
            channels[5],
            2,
            level_root=True,
            root_residual=residual_root)
        self.avgpool = nn.AvgPool2d(pool_size)
        self.fc = nn.Conv2d(
            channels[-1],
            num_classes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_level(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes:
            downsample = nn.Sequential(
                nn.MaxPool2d(stride, stride=stride),
                nn.Conv2d(
                    inplanes, planes, kernel_size=1, stride=1, bias=False),
                BatchNorm(planes),
            )
        layers = []
        layers.append(block(inplanes, planes, stride, downsample=downsample))
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))
        return nn.Sequential(*layers)

    def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
        modules = []
        for i in range(convs):
            modules.extend([
                nn.Conv2d(
                    inplanes,
                    planes,
                    kernel_size=3,
                    stride=stride if i == 0 else 1,
                    padding=dilation,
                    bias=False,
                    dilation=dilation),
                BatchNorm(planes),
                nn.ReLU(inplace=True)
            ])
            inplanes = planes
        return nn.Sequential(*modules)

    def forward(self, x):
        y = []
        x = self.base_layer(x)
        for i in range(6):
            x = getattr(self, 'level{}'.format(i))(x)
            y.append(x)
        if self.return_levels:
            return y
        else:
            x = self.avgpool(x)
            x = self.fc(x)
            x = x.view(x.size(0), -1)
            return x
def dla34(pretrained=None, **kwargs):  # DLA-34
    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512],
                block=BasicBlock,
                **kwargs)
    return model


def dla46_c(pretrained=None, **kwargs):  # DLA-46-C
    Bottleneck.expansion = 2
    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256],
                block=Bottleneck,
                **kwargs)
    return model


def dla46x_c(pretrained=None, **kwargs):  # DLA-X-46-C
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256],
                block=BottleneckX,
                **kwargs)
    return model


def dla60x_c(pretrained=None, **kwargs):  # DLA-X-60-C
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256],
                block=BottleneckX,
                **kwargs)
    return model


def dla60(pretrained=None, **kwargs):  # DLA-60
    Bottleneck.expansion = 2
    model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024],
                block=Bottleneck,
                **kwargs)
    return model


def dla60x(pretrained=None, **kwargs):  # DLA-X-60
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024],
                block=BottleneckX,
                **kwargs)
    return model


def dla102(pretrained=None, **kwargs):  # DLA-102
    Bottleneck.expansion = 2
    model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
                block=Bottleneck,
                residual_root=True,
                **kwargs)
    return model


def dla102x(pretrained=None, **kwargs):  # DLA-X-102
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
                block=BottleneckX,
                residual_root=True,
                **kwargs)
    return model


def dla102x2(pretrained=None, **kwargs):  # DLA-X-102 64
    BottleneckX.cardinality = 64
    model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
                block=BottleneckX,
                residual_root=True,
                **kwargs)
    return model


def dla169(pretrained=None, **kwargs):  # DLA-169
    Bottleneck.expansion = 2
    model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024],
                block=Bottleneck,
                residual_root=True,
                **kwargs)
    return model
def set_bn(bn):
    # Swap the normalization layer used by subsequently built modules.
    global BatchNorm
    BatchNorm = bn


class Identity(nn.Module):

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x
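# fill_up_weights initializes a grouped (one-channel-per-group) transposed
# convolution with a fixed bilinear-interpolation kernel. E.g. for the 4x4
# kernel used at upsampling factor 2 (f = 2, c = 0.75), each axis gets
# [0.25, 0.75, 0.75, 0.25] and the 2D kernel is their outer product.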
def fill_up_weights(up):
    w = up.weight.data
    f = math.ceil(w.size(2) / 2)
    c = (2 * f - 1 - f % 2) / (2. * f)
    for i in range(w.size(2)):
        for j in range(w.size(3)):
            w[0, 0, i, j] = \
                (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
    for c in range(1, w.size(0)):
        w[c, 0, :, :] = w[0, 0, :, :]
class IDAUp(nn.Module):

    def __init__(self, node_kernel, out_dim, channels, up_factors):
        super(IDAUp, self).__init__()
        self.channels = channels
        self.out_dim = out_dim
        for i, c in enumerate(channels):
            if c == out_dim:
                proj = Identity()
            else:
                proj = nn.Sequential(
                    nn.Conv2d(c, out_dim, kernel_size=1, stride=1, bias=False),
                    BatchNorm(out_dim), nn.ReLU(inplace=True))
            f = int(up_factors[i])
            if f == 1:
                up = Identity()
            else:
                up = nn.ConvTranspose2d(
                    out_dim,
                    out_dim,
                    f * 2,
                    stride=f,
                    padding=f // 2,
                    output_padding=0,
                    groups=out_dim,
                    bias=False)
                fill_up_weights(up)
            setattr(self, 'proj_' + str(i), proj)
            setattr(self, 'up_' + str(i), up)
        for i in range(1, len(channels)):
            node = nn.Sequential(
                nn.Conv2d(
                    out_dim * 2,
                    out_dim,
                    kernel_size=node_kernel,
                    stride=1,
                    padding=node_kernel // 2,
                    bias=False), BatchNorm(out_dim), nn.ReLU(inplace=True))
            setattr(self, 'node_' + str(i), node)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, layers):
        assert len(self.channels) == len(layers), \
            '{} vs {} layers'.format(len(self.channels), len(layers))
        layers = list(layers)
        for i, l in enumerate(layers):
            upsample = getattr(self, 'up_' + str(i))
            project = getattr(self, 'proj_' + str(i))
            layers[i] = upsample(project(l))
        x = layers[0]
        y = []
        for i in range(1, len(layers)):
            node = getattr(self, 'node_' + str(i))
            x = node(torch.cat([x, layers[i]], 1))
            y.append(x)
        return x, y
class DLAUp(nn.Module):

    def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None):
        super(DLAUp, self).__init__()
        if in_channels is None:
            in_channels = channels
        self.channels = channels
        channels = list(channels)
        scales = np.array(scales, dtype=int)
        for i in range(len(channels) - 1):
            j = -i - 2
            setattr(
                self, 'ida_{}'.format(i),
                IDAUp(3, channels[j], in_channels[j:],
                      scales[j:] // scales[j]))
            scales[j + 1:] = scales[j]
            in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]

    def forward(self, layers):
        layers = list(layers)
        assert len(layers) > 1
        for i in range(len(layers) - 1):
            ida = getattr(self, 'ida_{}'.format(i))
            x, y = ida(layers[-i - 2:])
            layers[-i - 1:] = y
        return x
def fill_fc_weights(layers):
    for m in layers.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.001)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
class DLASeg(nn.Module):

    def __init__(self,
                 base_name='dla34',
                 pretrained=False,
                 down_ratio=4,
                 head_conv=256):
        super(DLASeg, self).__init__()
        assert down_ratio in [2, 4, 8, 16]
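        # Output heads: 'hm' is a 2-channel heatmap (cell centers, table
        # vertices); 'c2v' regresses 8 offsets from a cell center to its 4
        # corners; 'v2c' regresses 8 offsets from a vertex to the 4 adjacent
        # cell centers; 'reg' is the 2-channel sub-pixel offset refinement.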
        self.heads = {'hm': 2, 'v2c': 8, 'c2v': 8, 'reg': 2}
        self.first_level = int(np.log2(down_ratio))
        self.base = globals()[base_name](
            pretrained=pretrained, return_levels=True)
        channels = self.base.channels
        scales = [2**i for i in range(len(channels[self.first_level:]))]
        self.dla_up = DLAUp(channels[self.first_level:], scales=scales)
        for head in self.heads:
            classes = self.heads[head]
            if head_conv > 0:
                fc = nn.Sequential(
                    nn.Conv2d(
                        channels[self.first_level],
                        head_conv,
                        kernel_size=3,
                        padding=1,
                        bias=True), nn.ReLU(inplace=True),
                    nn.Conv2d(
                        head_conv,
                        classes,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        bias=True))
                if 'hm' in head:
                    fc[-1].bias.data.fill_(-2.19)
                else:
                    fill_fc_weights(fc)
            else:
                fc = nn.Conv2d(
                    channels[self.first_level],
                    classes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=True)
                if 'hm' in head:
                    fc.bias.data.fill_(-2.19)
                else:
                    fill_fc_weights(fc)
            self.__setattr__(head, fc)

    def forward(self, x):
        x = self.base(x)
        x = self.dla_up(x[self.first_level:])
        ret = {}
        for head in self.heads:
            ret[head] = self.__getattr__(head)(x)
        return [ret]
def TableRecModel():
    model = DLASeg()
    return model
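A quick shape check of the new model (illustrative only; the weights are random here, the real checkpoint is loaded by the pipeline below). With down_ratio=4 the heads come out at 1/4 input resolution:

    import torch

    model = TableRecModel().eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 1024, 1024))[0]
    for name, tensor in out.items():
        print(name, tuple(tensor.shape))
    # hm (1, 2, 256, 256), v2c (1, 8, 256, 256),
    # c2v (1, 8, 256, 256), reg (1, 2, 256, 256)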
@@ -0,0 +1,315 @@
# ------------------------------------------------------------------------------
# The implementation is adapted from CenterNet,
# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
# ------------------------------------------------------------------------------
import copy
import math

import cv2
import numpy as np
import torch
import torch.nn as nn
def transform_preds(coords, center, scale, output_size, rot=0):
    target_coords = np.zeros(coords.shape)
    trans = get_affine_transform(center, scale, rot, output_size, inv=1)
    for p in range(coords.shape[0]):
        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
    return target_coords


def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale], dtype=np.float32)
    scale_tmp = scale
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]
    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)
    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
    return trans


def affine_transform(pt, t):
    new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2]


def get_dir(src_point, rot_rad):
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
    src_result = [0, 0]
    src_result[0] = src_point[0] * cs - src_point[1] * sn
    src_result[1] = src_point[0] * sn + src_point[1] * cs
    return src_result


def get_3rd_point(a, b):
    direct = a - b
    return b + np.array([-direct[1], direct[0]], dtype=np.float32)
def _sigmoid(x):
    y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
    return y


def _gather_feat(feat, ind, mask=None):
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat


def _transpose_and_gather_feat(feat, ind):
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    feat = _gather_feat(feat, ind)
    return feat


def _nms(heat, kernel=3):
    pad = (kernel - 1) // 2
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep, keep


def _topk(scores, K=40):
    batch, cat, height, width = scores.size()
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
    topk_inds = topk_inds % (height * width)
    topk_ys = (topk_inds / width).int().float()
    topk_xs = (topk_inds % width).int().float()
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    topk_clses = (topk_ind / K).int()
    topk_inds = _gather_feat(topk_inds.view(batch, -1, 1),
                             topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
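# bbox_decode combines the cell-center heatmap (channel 0 of 'hm'), the
# 8-channel center-to-corner offsets ('c2v') and the sub-pixel offsets
# ('reg') into (batch, K, 10) detections laid out as
# [x1, y1, x2, y2, x3, y3, x4, y4, score, class].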
def bbox_decode(heat, wh, reg=None, K=100):
    batch, cat, height, width = heat.size()
    heat, keep = _nms(heat)
    scores, inds, clses, ys, xs = _topk(heat, K=K)
    if reg is not None:
        reg = _transpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        xs = xs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
    wh = _transpose_and_gather_feat(wh, inds)
    wh = wh.view(batch, K, 8)
    clses = clses.view(batch, K, 1).float()
    scores = scores.view(batch, K, 1)
    bboxes = torch.cat(
        [
            xs - wh[..., 0:1],
            ys - wh[..., 1:2],
            xs - wh[..., 2:3],
            ys - wh[..., 3:4],
            xs - wh[..., 4:5],
            ys - wh[..., 5:6],
            xs - wh[..., 6:7],
            ys - wh[..., 7:8],
        ],
        dim=2,
    )
    detections = torch.cat([bboxes, scores, clses], dim=2)
    return detections, keep
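# gbox_decode is the vertex-branch counterpart: channel 1 of 'hm' gives
# table vertices and 'v2c' regresses the 4 adjacent cell centers, so each
# row is [vx, vy, cx1, cy1, cx2, cy2, cx3, cy3, cx4, cy4, score, class].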
def gbox_decode(mk, st_reg, reg=None, K=400):
    batch, cat, height, width = mk.size()
    mk, keep = _nms(mk)
    scores, inds, clses, ys, xs = _topk(mk, K=K)
    if reg is not None:
        reg = _transpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        xs = xs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
    scores = scores.view(batch, K, 1)
    clses = clses.view(batch, K, 1).float()
    st_reg = _transpose_and_gather_feat(st_reg, inds)
    bboxes = torch.cat(
        [
            xs - st_reg[..., 0:1],
            ys - st_reg[..., 1:2],
            xs - st_reg[..., 2:3],
            ys - st_reg[..., 3:4],
            xs - st_reg[..., 4:5],
            ys - st_reg[..., 5:6],
            xs - st_reg[..., 6:7],
            ys - st_reg[..., 7:8],
        ],
        dim=2,
    )
    return torch.cat([xs, ys, bboxes, scores, clses], dim=2), keep
def bbox_post_process(bbox, c, s, h, w):
    for i in range(bbox.shape[0]):
        bbox[i, :, 0:2] = transform_preds(bbox[i, :, 0:2], c[i], s[i], (w, h))
        bbox[i, :, 2:4] = transform_preds(bbox[i, :, 2:4], c[i], s[i], (w, h))
        bbox[i, :, 4:6] = transform_preds(bbox[i, :, 4:6], c[i], s[i], (w, h))
        bbox[i, :, 6:8] = transform_preds(bbox[i, :, 6:8], c[i], s[i], (w, h))
    return bbox


def gbox_post_process(gbox, c, s, h, w):
    for i in range(gbox.shape[0]):
        gbox[i, :, 0:2] = transform_preds(gbox[i, :, 0:2], c[i], s[i], (w, h))
        gbox[i, :, 2:4] = transform_preds(gbox[i, :, 2:4], c[i], s[i], (w, h))
        gbox[i, :, 4:6] = transform_preds(gbox[i, :, 4:6], c[i], s[i], (w, h))
        gbox[i, :, 6:8] = transform_preds(gbox[i, :, 6:8], c[i], s[i], (w, h))
        gbox[i, :, 8:10] = transform_preds(gbox[i, :, 8:10], c[i], s[i],
                                           (w, h))
    return gbox
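# Quadrilateral NMS: a detection is meant to be suppressed when its center
# falls strictly inside a higher-scoring box (score at index 8), tested with
# four cross products against the other box's edges. Note that for the
# single-image batches the pipeline passes in, len(dets) is the batch size
# (1), so the early return makes this a no-op in practice.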
def nms(dets, thresh):
    if len(dets) < 2:
        return dets
    index_keep = []
    keep = []
    for i in range(len(dets)):
        box = dets[i]
        if box[-1] < thresh:
            break
        max_score_index = -1
        ctx = (dets[i][0] + dets[i][2] + dets[i][4] + dets[i][6]) / 4
        cty = (dets[i][1] + dets[i][3] + dets[i][5] + dets[i][7]) / 4
        for j in range(len(dets)):
            if i == j or dets[j][-1] < thresh:
                break
            x1, y1 = dets[j][0], dets[j][1]
            x2, y2 = dets[j][2], dets[j][3]
            x3, y3 = dets[j][4], dets[j][5]
            x4, y4 = dets[j][6], dets[j][7]
            a = (x2 - x1) * (cty - y1) - (y2 - y1) * (ctx - x1)
            b = (x3 - x2) * (cty - y2) - (y3 - y2) * (ctx - x2)
            c = (x4 - x3) * (cty - y3) - (y4 - y3) * (ctx - x3)
            d = (x1 - x4) * (cty - y4) - (y1 - y4) * (ctx - x4)
            if (a > 0 and b > 0 and c > 0 and d > 0) or (a < 0 and b < 0
                                                         and c < 0 and d < 0):
                if dets[i][8] > dets[j][8] and max_score_index < 0:
                    max_score_index = i
                elif dets[i][8] < dets[j][8]:
                    max_score_index = -2
                    break
        if max_score_index > -1:
            index_keep.append(max_score_index)
        elif max_score_index == -1:
            index_keep.append(i)
    for i in range(0, len(index_keep)):
        keep.append(dets[index_keep[i]])
    return np.array(keep)
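# group_bbox_by_gbox snaps cell corners to shared table vertices: for every
# vertex detection, each of its 4 regressed cell centers is matched to the
# cell box containing it, and that box's nearest corner (within
# c2v_dist_thred * max(w, h)) is moved onto the vertex, so adjacent cells
# end up sharing identical corner coordinates.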
def group_bbox_by_gbox(bboxes,
                       gboxes,
                       score_thred=0.3,
                       v2c_dist_thred=2,
                       c2v_dist_thred=0.5):

    def point_in_box(box, point):
        x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
        x3, y3, x4, y4 = box[4], box[5], box[6], box[7]
        ctx, cty = point[0], point[1]
        a = (x2 - x1) * (cty - y1) - (y2 - y1) * (ctx - x1)
        b = (x3 - x2) * (cty - y2) - (y3 - y2) * (ctx - x2)
        c = (x4 - x3) * (cty - y3) - (y4 - y3) * (ctx - x3)
        d = (x1 - x4) * (cty - y4) - (y1 - y4) * (ctx - x4)
        if (a > 0 and b > 0 and c > 0 and d > 0) or (a < 0 and b < 0 and c < 0
                                                     and d < 0):
            return True
        else:
            return False

    def get_distance(pt1, pt2):
        return math.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0])
                         + (pt1[1] - pt2[1]) * (pt1[1] - pt2[1]))

    dets = copy.deepcopy(bboxes)
    sign = np.zeros((len(dets), 4))
    for idx, gbox in enumerate(gboxes):  # vertex x, y, 4 cell centers, score
        if gbox[10] < score_thred:
            break
        vertex = [gbox[0], gbox[1]]
        for i in range(0, 4):
            center = [gbox[2 * i + 2], gbox[2 * i + 3]]
            if get_distance(vertex, center) < v2c_dist_thred:
                continue
            for k, bbox in enumerate(dets):
                if bbox[8] < score_thred:
                    break
                if sum(sign[k]) == 4:
                    continue
                w = (abs(bbox[6] - bbox[0]) + abs(bbox[4] - bbox[2])) / 2
                h = (abs(bbox[3] - bbox[1]) + abs(bbox[5] - bbox[7])) / 2
                m = max(w, h)
                if point_in_box(bbox, center):
                    min_dist, min_id = 1e4, -1
                    for j in range(0, 4):
                        dist = get_distance(vertex,
                                            [bbox[2 * j], bbox[2 * j + 1]])
                        if dist < min_dist:
                            min_dist = dist
                            min_id = j
                    if (min_id > -1 and min_dist < c2v_dist_thred * m
                            and sign[k][min_id] == 0):
                        bboxes[k][2 * min_id] = vertex[0]
                        bboxes[k][2 * min_id + 1] = vertex[1]
                        sign[k][min_id] = 1
    return bboxes
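A small round-trip check of the affine helpers above (hypothetical values: a 640x480 image mapped onto the 1024x1024 network input, as the pipeline's preprocess does):

    import numpy as np

    c = np.array([320., 240.], dtype=np.float32)  # image center
    s = 640.0                                     # max(height, width)
    fwd = get_affine_transform(c, s, 0, [1024, 1024])
    inv = get_affine_transform(c, s, 0, [1024, 1024], inv=1)
    pt = affine_transform([100., 200.], fwd)
    back = affine_transform(pt, inv)
    print(np.allclose(back, [100., 200.], atol=1e-3))  # True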
@@ -0,0 +1,119 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.ocr_utils.model_dla34 import TableRecModel
from modelscope.pipelines.cv.ocr_utils.table_process import (
    bbox_decode, bbox_post_process, gbox_decode, gbox_post_process,
    get_affine_transform, group_bbox_by_gbox, nms)
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()
@PIPELINES.register_module(
    Tasks.table_recognition, module_name=Pipelines.table_recognition)
class TableRecognitionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        self.K = 1000  # max cell-center detections per image
        self.MK = 4000  # max vertex detections per image
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.infer_model = TableRecModel().to(self.device)
        self.infer_model.eval()
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'state_dict' in checkpoint:
            self.infer_model.load_state_dict(checkpoint['state_dict'])
        else:
            self.infer_model.load_state_dict(checkpoint)
    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        mean = np.array([0.408, 0.447, 0.470],
                        dtype=np.float32).reshape(1, 1, 3)
        std = np.array([0.289, 0.274, 0.278],
                       dtype=np.float32).reshape(1, 1, 3)
        height, width = img.shape[0:2]
        inp_height, inp_width = 1024, 1024
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(height, width) * 1.0
        trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
        # resize to the original size is a no-op; kept from the upstream
        # multi-scale testing code
        resized_image = cv2.resize(img, (width, height))
        inp_image = cv2.warpAffine(
            resized_image,
            trans_input, (inp_width, inp_height),
            flags=cv2.INTER_LINEAR)
        inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)
        images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
                                                      inp_width)
        images = torch.from_numpy(images).to(self.device)
        meta = {
            'c': c,
            's': s,
            'input_height': inp_height,
            'input_width': inp_width,
            'out_height': inp_height // 4,
            'out_width': inp_width // 4
        }
        result = {'img': images, 'meta': meta}
        return result
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        pred = self.infer_model(input['img'])
        return {'results': pred, 'meta': input['meta']}
    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        output = inputs['results'][0]
        meta = inputs['meta']
        hm = output['hm'].sigmoid_()
        v2c = output['v2c']
        c2v = output['c2v']
        reg = output['reg']
        bbox, _ = bbox_decode(hm[:, 0:1, :, :], c2v, reg=reg, K=self.K)
        gbox, _ = gbox_decode(hm[:, 1:2, :, :], v2c, reg=reg, K=self.MK)
        bbox = bbox.detach().cpu().numpy()
        gbox = gbox.detach().cpu().numpy()
        bbox = nms(bbox, 0.3)
        bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()],
                                 [meta['s']], meta['out_height'],
                                 meta['out_width'])
        gbox = gbox_post_process(gbox.copy(), [meta['c'].cpu().numpy()],
                                 [meta['s']], meta['out_height'],
                                 meta['out_width'])
        bbox = group_bbox_by_gbox(bbox[0], gbox[0])
        res = []
        for box in bbox:
            if box[8] > 0.3:
                res.append(box[0:8])
        result = {OutputKeys.POLYGONS: np.array(res)}
        return result
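For visual inspection of the result, a hedged sketch (draw_polygons is not part of this PR; it assumes a BGR image loaded separately with OpenCV):

    import cv2
    import numpy as np

    def draw_polygons(image, polygons):
        # Each row of `polygons` is [x1, y1, ..., x4, y4]: four corner
        # points per table cell.
        for poly in np.asarray(polygons).reshape(-1, 4, 2).astype(np.int32):
            cv2.polylines(image, [poly], True, (0, 255, 0), 2)
        return image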
@@ -16,6 +16,7 @@ class CVTasks(object):
    # ocr
    ocr_detection = 'ocr-detection'
    ocr_recognition = 'ocr-recognition'
    table_recognition = 'table-recognition'

    # human face body related
    animal_recognition = 'animal-recognition'
@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class TableRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_dla34_table-structure-recognition_cycle-centernet'
        self.test_image = 'data/test/images/table_recognition.jpg'
        self.task = Tasks.table_recognition

    def pipeline_inference(self, pipe: Pipeline, input_location: str):
        result = pipe(input_location)
        print('table recognition results:')
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        table_recognition = pipeline(
            Tasks.table_recognition, model=self.model_id)
        self.pipeline_inference(table_recognition, self.test_image)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        table_recognition = pipeline(Tasks.table_recognition)
        self.pipeline_inference(table_recognition, self.test_image)

    @unittest.skip('demo compatibility test is only enabled on an as-needed basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()
@@ -39,6 +39,7 @@ isolated: # test cases that may require an excessive amount of GPU memory or runtime
- test_automatic_speech_recognition.py
- test_image_matting.py
- test_skin_retouching.py
- test_table_recognition.py

envs:
  default: # default env; cases not assigned to another env run here (pytorch)