Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10773667 (master^2)
New binary asset, stored as a Git LFS pointer — likely the test image data/test/images/table_recognition.jpg referenced by the new unit test below:

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4b7e23f02a35136707ac7862e0a8468797f239e89497351847cfacb2a9c24f6
+size 202112
modelscope/metainfo.py — register the new pipeline name:

@@ -151,6 +151,7 @@ class Pipelines(object):
     image_denoise = 'nafnet-image-denoise'
     person_image_cartoon = 'unet-person-image-cartoon'
     ocr_detection = 'resnet18-ocr-detection'
+    table_recognition = 'dla34-table-recognition'
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognition = 'resnet101-animal-recognition'
     general_recognition = 'resnet101-general-recognition'
modelscope/outputs.py — declare the task's output schema (same polygon format as OCR detection):

@@ -59,6 +59,7 @@ TASK_OUTPUTS = {
     #       [x1, y1, x2, y2, x3, y3, x4, y4]
     # }
     Tasks.ocr_detection: [OutputKeys.POLYGONS],
+    Tasks.table_recognition: [OutputKeys.POLYGONS],

     # ocr recognition result for single sample
     # {
modelscope/pipelines/builder.py — map the task to its default model:

@@ -82,6 +82,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                  'damo/cv_unet_person-image-cartoon_compound-models'),
     Tasks.ocr_detection: (Pipelines.ocr_detection,
                           'damo/cv_resnet18_ocr-detection-line-level_damo'),
+    Tasks.table_recognition:
+    (Pipelines.table_recognition,
+     'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
     Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
     Tasks.feature_extraction: (Pipelines.feature_extraction,
                                'damo/pert_feature-extraction_base-test'),
modelscope/pipelines/cv/__init__.py — expose the new pipeline class on both the eager (type-checking) and lazy import paths:

@@ -41,6 +41,7 @@ if TYPE_CHECKING:
     from .live_category_pipeline import LiveCategoryPipeline
     from .ocr_detection_pipeline import OCRDetectionPipeline
     from .ocr_recognition_pipeline import OCRRecognitionPipeline
+    from .table_recognition_pipeline import TableRecognitionPipeline
     from .skin_retouching_pipeline import SkinRetouchingPipeline
     from .tinynas_classification_pipeline import TinynasClassificationPipeline
     from .video_category_pipeline import VideoCategoryPipeline

@@ -108,6 +109,7 @@ else:
         'image_inpainting_pipeline': ['ImageInpaintingPipeline'],
         'ocr_detection_pipeline': ['OCRDetectionPipeline'],
         'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
+        'table_recognition_pipeline': ['TableRecognitionPipeline'],
         'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
         'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],
         'video_category_pipeline': ['VideoCategoryPipeline'],
modelscope/pipelines/cv/ocr_utils/model_dla34.py (new file; path per the pipeline's imports below) — the DLA-34 backbone and DLASeg head used by CycleCenterNet:

@@ -0,0 +1,655 @@
# ------------------------------------------------------------------------------
# The implementation is adapted from CenterNet,
# made publicly available under the MIT License at
# https://github.com/xingyizhou/CenterNet.git
# ------------------------------------------------------------------------------
import math

import numpy as np
import torch
from torch import nn

# Module-level alias so the normalization layer can be swapped via set_bn().
BatchNorm = nn.BatchNorm2d


class BasicBlock(nn.Module):

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation)
        self.bn1 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=dilation,
            bias=False,
            dilation=dilation)
        self.bn2 = BatchNorm(planes)
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(Bottleneck, self).__init__()
        expansion = Bottleneck.expansion
        bottle_planes = planes // expansion
        self.conv1 = nn.Conv2d(
            inplanes, bottle_planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(bottle_planes)
        self.conv2 = nn.Conv2d(
            bottle_planes,
            bottle_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation)
        self.bn2 = BatchNorm(bottle_planes)
        self.conv3 = nn.Conv2d(
            bottle_planes, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += residual
        out = self.relu(out)
        return out


class BottleneckX(nn.Module):
    expansion = 2
    cardinality = 32

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BottleneckX, self).__init__()
        cardinality = BottleneckX.cardinality
        bottle_planes = planes * cardinality // 32
        self.conv1 = nn.Conv2d(
            inplanes, bottle_planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(bottle_planes)
        self.conv2 = nn.Conv2d(
            bottle_planes,
            bottle_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation,
            groups=cardinality)
        self.bn2 = BatchNorm(bottle_planes)
        self.conv3 = nn.Conv2d(
            bottle_planes, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += residual
        out = self.relu(out)
        return out


class Root(nn.Module):
    """Aggregation node: concatenates its children, fuses them with a 1x1 conv."""

    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            1,
            stride=1,
            bias=False,
            padding=(kernel_size - 1) // 2)
        self.bn = BatchNorm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *x):
        children = x
        x = self.conv(torch.cat(x, 1))
        x = self.bn(x)
        if self.residual:
            x += children[0]
        x = self.relu(x)
        return x


class Tree(nn.Module):
    """Recursive hierarchical aggregation: two subtrees joined by a Root node."""

    def __init__(self,
                 levels,
                 block,
                 in_channels,
                 out_channels,
                 stride=1,
                 level_root=False,
                 root_dim=0,
                 root_kernel_size=1,
                 dilation=1,
                 root_residual=False):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels
        if levels == 1:
            self.tree1 = block(
                in_channels, out_channels, stride, dilation=dilation)
            self.tree2 = block(
                out_channels, out_channels, 1, dilation=dilation)
        else:
            self.tree1 = Tree(
                levels - 1,
                block,
                in_channels,
                out_channels,
                stride,
                root_dim=0,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                root_residual=root_residual)
            self.tree2 = Tree(
                levels - 1,
                block,
                out_channels,
                out_channels,
                root_dim=root_dim + out_channels,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                root_residual=root_residual)
        if levels == 1:
            self.root = Root(root_dim, out_channels, root_kernel_size,
                             root_residual)
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.levels = levels
        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        if in_channels != out_channels:
            self.project = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    bias=False), BatchNorm(out_channels))

    def forward(self, x, residual=None, children=None):
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        residual = self.project(bottom) if self.project else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.levels == 1:
            x2 = self.tree2(x1)
            x = self.root(x2, x1, *children)
        else:
            children.append(x1)
            x = self.tree2(x1, children=children)
        return x


class DLA(nn.Module):
    """Deep Layer Aggregation backbone. With return_levels=True, forward()
    returns the six intermediate feature maps instead of classification logits."""

    def __init__(self,
                 levels,
                 channels,
                 num_classes=1000,
                 block=BasicBlock,
                 residual_root=False,
                 return_levels=False,
                 pool_size=7,
                 linear_root=False):
        super(DLA, self).__init__()
        self.channels = channels
        self.return_levels = return_levels
        self.num_classes = num_classes
        self.base_layer = nn.Sequential(
            nn.Conv2d(
                3, channels[0], kernel_size=7, stride=1, padding=3,
                bias=False), BatchNorm(channels[0]), nn.ReLU(inplace=True))
        self.level0 = self._make_conv_level(channels[0], channels[0],
                                            levels[0])
        self.level1 = self._make_conv_level(
            channels[0], channels[1], levels[1], stride=2)
        self.level2 = Tree(
            levels[2],
            block,
            channels[1],
            channels[2],
            2,
            level_root=False,
            root_residual=residual_root)
        self.level3 = Tree(
            levels[3],
            block,
            channels[2],
            channels[3],
            2,
            level_root=True,
            root_residual=residual_root)
        self.level4 = Tree(
            levels[4],
            block,
            channels[3],
            channels[4],
            2,
            level_root=True,
            root_residual=residual_root)
        self.level5 = Tree(
            levels[5],
            block,
            channels[4],
            channels[5],
            2,
            level_root=True,
            root_residual=residual_root)
        self.avgpool = nn.AvgPool2d(pool_size)
        self.fc = nn.Conv2d(
            channels[-1],
            num_classes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_level(self, block, inplanes, planes, blocks, stride=1):
        # NOTE: unused upstream helper, kept for parity with CenterNet; the
        # blocks defined in this file do not accept a `downsample` argument.
        downsample = None
        if stride != 1 or inplanes != planes:
            downsample = nn.Sequential(
                nn.MaxPool2d(stride, stride=stride),
                nn.Conv2d(
                    inplanes, planes, kernel_size=1, stride=1, bias=False),
                BatchNorm(planes),
            )
        layers = []
        layers.append(block(inplanes, planes, stride, downsample=downsample))
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))
        return nn.Sequential(*layers)

    def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
        modules = []
        for i in range(convs):
            modules.extend([
                nn.Conv2d(
                    inplanes,
                    planes,
                    kernel_size=3,
                    stride=stride if i == 0 else 1,
                    padding=dilation,
                    bias=False,
                    dilation=dilation),
                BatchNorm(planes),
                nn.ReLU(inplace=True)
            ])
            inplanes = planes
        return nn.Sequential(*modules)

    def forward(self, x):
        y = []
        x = self.base_layer(x)
        for i in range(6):
            x = getattr(self, 'level{}'.format(i))(x)
            y.append(x)
        if self.return_levels:
            return y
        else:
            x = self.avgpool(x)
            x = self.fc(x)
            x = x.view(x.size(0), -1)
            return x

# The `pretrained` argument is accepted for API compatibility and ignored in
# all factory functions below; weights are loaded by the pipeline from a
# local checkpoint.
def dla34(pretrained=None, **kwargs):  # DLA-34
    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512],
                block=BasicBlock,
                **kwargs)
    return model


def dla46_c(pretrained=None, **kwargs):  # DLA-46-C
    Bottleneck.expansion = 2
    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256],
                block=Bottleneck,
                **kwargs)
    return model


def dla46x_c(pretrained=None, **kwargs):  # DLA-X-46-C
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256],
                block=BottleneckX,
                **kwargs)
    return model


def dla60x_c(pretrained=None, **kwargs):  # DLA-X-60-C
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256],
                block=BottleneckX,
                **kwargs)
    return model


def dla60(pretrained=None, **kwargs):  # DLA-60
    Bottleneck.expansion = 2
    model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024],
                block=Bottleneck,
                **kwargs)
    return model


def dla60x(pretrained=None, **kwargs):  # DLA-X-60
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024],
                block=BottleneckX,
                **kwargs)
    return model


def dla102(pretrained=None, **kwargs):  # DLA-102
    Bottleneck.expansion = 2
    model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
                block=Bottleneck,
                residual_root=True,
                **kwargs)
    return model


def dla102x(pretrained=None, **kwargs):  # DLA-X-102
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
                block=BottleneckX,
                residual_root=True,
                **kwargs)
    return model


def dla102x2(pretrained=None, **kwargs):  # DLA-X-102 64
    BottleneckX.cardinality = 64
    model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
                block=BottleneckX,
                residual_root=True,
                **kwargs)
    return model


def dla169(pretrained=None, **kwargs):  # DLA-169
    Bottleneck.expansion = 2
    model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024],
                block=Bottleneck,
                residual_root=True,
                **kwargs)
    return model
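As a quick sanity check (editor's sketch, not part of the diff): with `return_levels=True`, `dla34` exposes the six-level feature pyramid the detection head consumes, with strides 1 through 32 and the channel list `[16, 32, 64, 128, 256, 512]`.

```python
# Sketch: verify the DLA-34 feature pyramid (assumes the definitions above).
import torch

net = dla34(return_levels=True).eval()
with torch.no_grad():
    feats = net(torch.randn(1, 3, 256, 256))
for i, f in enumerate(feats):
    print(i, tuple(f.shape))
# (1, 16, 256, 256), (1, 32, 128, 128), (1, 64, 64, 64),
# (1, 128, 32, 32), (1, 256, 16, 16), (1, 512, 8, 8)
```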

def set_bn(bn):
    # Swap the normalization layer used by every module defined above.
    # (The upstream CenterNet line `dla.BatchNorm = bn` referenced a module
    # that does not exist here and would raise NameError; it was removed.)
    global BatchNorm
    BatchNorm = bn


class Identity(nn.Module):

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


def fill_up_weights(up):
    # Initialize a (grouped) ConvTranspose2d with the bilinear-interpolation
    # kernel, so the layer starts out as exact bilinear upsampling.
    w = up.weight.data
    f = math.ceil(w.size(2) / 2)
    c = (2 * f - 1 - f % 2) / (2. * f)
    for i in range(w.size(2)):
        for j in range(w.size(3)):
            w[0, 0, i, j] = \
                (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
    for c in range(1, w.size(0)):
        w[c, 0, :, :] = w[0, 0, :, :]
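A small check of the resulting 4x4 stride-2 kernel (editor's sketch, not part of the diff): the weights form the outer product of `[0.25, 0.75, 0.75, 0.25]`, the classic bilinear kernel.

```python
import torch
from torch import nn

up = nn.ConvTranspose2d(1, 1, 4, stride=2, padding=1, bias=False)
fill_up_weights(up)
print(up.weight.data[0, 0])
# tensor([[0.0625, 0.1875, 0.1875, 0.0625],
#         [0.1875, 0.5625, 0.5625, 0.1875],
#         [0.1875, 0.5625, 0.5625, 0.1875],
#         [0.0625, 0.1875, 0.1875, 0.0625]])
```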

class IDAUp(nn.Module):
    """Iterative Deep Aggregation: project every input to out_dim, upsample
    to a common resolution, then fuse pairwise from shallow to deep."""

    def __init__(self, node_kernel, out_dim, channels, up_factors):
        super(IDAUp, self).__init__()
        self.channels = channels
        self.out_dim = out_dim
        for i, c in enumerate(channels):
            if c == out_dim:
                proj = Identity()
            else:
                proj = nn.Sequential(
                    nn.Conv2d(c, out_dim, kernel_size=1, stride=1, bias=False),
                    BatchNorm(out_dim), nn.ReLU(inplace=True))
            f = int(up_factors[i])
            if f == 1:
                up = Identity()
            else:
                up = nn.ConvTranspose2d(
                    out_dim,
                    out_dim,
                    f * 2,
                    stride=f,
                    padding=f // 2,
                    output_padding=0,
                    groups=out_dim,
                    bias=False)
                fill_up_weights(up)
            setattr(self, 'proj_' + str(i), proj)
            setattr(self, 'up_' + str(i), up)
        for i in range(1, len(channels)):
            node = nn.Sequential(
                nn.Conv2d(
                    out_dim * 2,
                    out_dim,
                    kernel_size=node_kernel,
                    stride=1,
                    padding=node_kernel // 2,
                    bias=False), BatchNorm(out_dim), nn.ReLU(inplace=True))
            setattr(self, 'node_' + str(i), node)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, layers):
        assert len(self.channels) == len(layers), \
            '{} vs {} layers'.format(len(self.channels), len(layers))
        layers = list(layers)
        for i, l in enumerate(layers):
            upsample = getattr(self, 'up_' + str(i))
            project = getattr(self, 'proj_' + str(i))
            layers[i] = upsample(project(l))
        x = layers[0]
        y = []
        for i in range(1, len(layers)):
            node = getattr(self, 'node_' + str(i))
            x = node(torch.cat([x, layers[i]], 1))
            y.append(x)
        return x, y


class DLAUp(nn.Module):
    """Merges the feature pyramid with IDAUp stages, from deepest to shallowest."""

    def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None):
        super(DLAUp, self).__init__()
        if in_channels is None:
            in_channels = channels
        self.channels = channels
        channels = list(channels)
        scales = np.array(scales, dtype=int)
        for i in range(len(channels) - 1):
            j = -i - 2
            setattr(
                self, 'ida_{}'.format(i),
                IDAUp(3, channels[j], in_channels[j:],
                      scales[j:] // scales[j]))
            scales[j + 1:] = scales[j]
            in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]

    def forward(self, layers):
        layers = list(layers)
        assert len(layers) > 1
        for i in range(len(layers) - 1):
            ida = getattr(self, 'ida_{}'.format(i))
            x, y = ida(layers[-i - 2:])
            layers[-i - 1:] = y
        return x


def fill_fc_weights(layers):
    for m in layers.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.001)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


class DLASeg(nn.Module):

    def __init__(self,
                 base_name='dla34',
                 pretrained=False,
                 down_ratio=4,
                 head_conv=256):
        super(DLASeg, self).__init__()
        assert down_ratio in [2, 4, 8, 16]
        # CycleCenterNet heads:
        #   hm  - 2 heatmaps (channel 0: cell centers, channel 1: cell vertices)
        #   v2c - 8-channel vertex-to-center offsets (4 neighbouring centers)
        #   c2v - 8-channel center-to-vertex offsets (4 cell corners)
        #   reg - 2-channel sub-pixel offset refinement
        self.heads = {'hm': 2, 'v2c': 8, 'c2v': 8, 'reg': 2}
        self.first_level = int(np.log2(down_ratio))
        self.base = globals()[base_name](
            pretrained=pretrained, return_levels=True)
        channels = self.base.channels
        scales = [2**i for i in range(len(channels[self.first_level:]))]
        self.dla_up = DLAUp(channels[self.first_level:], scales=scales)
        for head in self.heads:
            classes = self.heads[head]
            if head_conv > 0:
                fc = nn.Sequential(
                    nn.Conv2d(
                        channels[self.first_level],
                        head_conv,
                        kernel_size=3,
                        padding=1,
                        bias=True), nn.ReLU(inplace=True),
                    nn.Conv2d(
                        head_conv,
                        classes,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        bias=True))
                if 'hm' in head:
                    # Focal-loss bias prior (~ -log 9), as in CenterNet.
                    fc[-1].bias.data.fill_(-2.19)
                else:
                    fill_fc_weights(fc)
            else:
                fc = nn.Conv2d(
                    channels[self.first_level],
                    classes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=True)
                if 'hm' in head:
                    fc.bias.data.fill_(-2.19)
                else:
                    fill_fc_weights(fc)
            self.__setattr__(head, fc)

    def forward(self, x):
        x = self.base(x)
        x = self.dla_up(x[self.first_level:])
        ret = {}
        for head in self.heads:
            ret[head] = self.__getattr__(head)(x)
        return [ret]


def TableRecModel():
    model = DLASeg()
    return model
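`TableRecModel` is `DLASeg` with default arguments, i.e. a DLA-34 backbone at `down_ratio=4` with the four CycleCenterNet heads. A shape sketch (editor's, not part of the diff; input sides must be divisible by 32, so a 512x512 input yields 128x128 maps):

```python
import torch

model = TableRecModel().eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 512))[0]  # forward returns [dict]
for name, tensor in out.items():
    print(name, tuple(tensor.shape))
# hm  (1, 2, 128, 128)   heatmaps: cell centers and cell vertices
# v2c (1, 8, 128, 128)   vertex-to-center offsets
# c2v (1, 8, 128, 128)   center-to-corner offsets
# reg (1, 2, 128, 128)   sub-pixel refinement
```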
modelscope/pipelines/cv/ocr_utils/table_process.py (new file; path per the pipeline's imports below) — decoding and post-processing utilities:

@@ -0,0 +1,315 @@
# ------------------------------------------------------------------------------
# The implementation is adapted from CenterNet,
# made publicly available under the MIT License at
# https://github.com/xingyizhou/CenterNet.git
# ------------------------------------------------------------------------------
import copy
import math

import cv2
import numpy as np
import torch
import torch.nn as nn


def transform_preds(coords, center, scale, output_size, rot=0):
    # Map points from the network's output grid back to the original image.
    target_coords = np.zeros(coords.shape)
    trans = get_affine_transform(center, scale, rot, output_size, inv=1)
    for p in range(coords.shape[0]):
        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
    return target_coords


def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale], dtype=np.float32)
    scale_tmp = scale
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]
    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)
    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
    return trans


def affine_transform(pt, t):
    new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2]


def get_dir(src_point, rot_rad):
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
    src_result = [0, 0]
    src_result[0] = src_point[0] * cs - src_point[1] * sn
    src_result[1] = src_point[0] * sn + src_point[1] * cs
    return src_result


def get_3rd_point(a, b):
    direct = a - b
    return b + np.array([-direct[1], direct[0]], dtype=np.float32)
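These helpers build the affine mapping CenterNet uses between the original image and the fixed network input/output grids; `transform_preds` applies the inverse mapping to bring predictions back to image coordinates. A worked example (editor's sketch, not part of the diff):

```python
import numpy as np

# Map the center of a 256x256 output grid back to a 1280x720 image.
c = np.array([640., 360.], dtype=np.float32)  # image center
s = 1280.0                                    # max(height, width)
pt = np.array([[128., 128.]])                 # point in output-map coords
print(transform_preds(pt, c, s, (256, 256)))  # -> [[640. 360.]]
```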

def _sigmoid(x):
    y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
    return y


def _gather_feat(feat, ind, mask=None):
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat


def _transpose_and_gather_feat(feat, ind):
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    feat = _gather_feat(feat, ind)
    return feat


def _nms(heat, kernel=3):
    # Keep only local maxima of the heatmap (max-pooling trick from CenterNet).
    pad = (kernel - 1) // 2
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep, keep


def _topk(scores, K=40):
    batch, cat, height, width = scores.size()
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
    topk_inds = topk_inds % (height * width)
    topk_ys = (topk_inds // width).float()
    topk_xs = (topk_inds % width).float()
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    topk_clses = (topk_ind // K).int()
    topk_inds = _gather_feat(topk_inds.view(batch, -1, 1),
                             topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs


def bbox_decode(heat, wh, reg=None, K=100):
    # Decode the K highest-scoring cell centers into quadrilateral boxes;
    # each row of `wh` holds the 8 center-to-corner offsets.
    batch, cat, height, width = heat.size()
    heat, keep = _nms(heat)
    scores, inds, clses, ys, xs = _topk(heat, K=K)
    if reg is not None:
        reg = _transpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        xs = xs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
    wh = _transpose_and_gather_feat(wh, inds)
    wh = wh.view(batch, K, 8)
    clses = clses.view(batch, K, 1).float()
    scores = scores.view(batch, K, 1)
    bboxes = torch.cat(
        [
            xs - wh[..., 0:1],
            ys - wh[..., 1:2],
            xs - wh[..., 2:3],
            ys - wh[..., 3:4],
            xs - wh[..., 4:5],
            ys - wh[..., 5:6],
            xs - wh[..., 6:7],
            ys - wh[..., 7:8],
        ],
        dim=2,
    )
    detections = torch.cat([bboxes, scores, clses], dim=2)
    return detections, keep
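`bbox_decode` thus turns the center heatmap, the 8-channel corner regression, and the offset map into the K best cell candidates; each row of the result is `[x1, y1, ..., x4, y4, score, class]`. A shape-only sketch with random tensors (editor's, not part of the diff):

```python
import torch

heat = torch.rand(1, 1, 64, 64)   # cell-center heatmap (already in [0, 1])
c2v = torch.randn(1, 8, 64, 64)   # center-to-corner offsets
reg = torch.randn(1, 2, 64, 64)   # sub-pixel refinement
dets, keep = bbox_decode(heat, c2v, reg=reg, K=100)
print(dets.shape)  # torch.Size([1, 100, 10]) -> 8 coords + score + class
```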

def gbox_decode(mk, st_reg, reg=None, K=400):
    # Decode the K highest-scoring vertices together with their four regressed
    # neighbouring cell centers; each output row is
    # [vx, vy, cx1, cy1, ..., cx4, cy4, score, class].
    batch, cat, height, width = mk.size()
    mk, keep = _nms(mk)
    scores, inds, clses, ys, xs = _topk(mk, K=K)
    if reg is not None:
        reg = _transpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        xs = xs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
    scores = scores.view(batch, K, 1)
    clses = clses.view(batch, K, 1).float()
    st_Reg = _transpose_and_gather_feat(st_reg, inds)
    bboxes = torch.cat(
        [
            xs - st_Reg[..., 0:1],
            ys - st_Reg[..., 1:2],
            xs - st_Reg[..., 2:3],
            ys - st_Reg[..., 3:4],
            xs - st_Reg[..., 4:5],
            ys - st_Reg[..., 5:6],
            xs - st_Reg[..., 6:7],
            ys - st_Reg[..., 7:8],
        ],
        dim=2,
    )
    return torch.cat([xs, ys, bboxes, scores, clses], dim=2), keep


def bbox_post_process(bbox, c, s, h, w):
    # Map decoded cell corners from the output grid back to image coordinates.
    for i in range(bbox.shape[0]):
        bbox[i, :, 0:2] = transform_preds(bbox[i, :, 0:2], c[i], s[i], (w, h))
        bbox[i, :, 2:4] = transform_preds(bbox[i, :, 2:4], c[i], s[i], (w, h))
        bbox[i, :, 4:6] = transform_preds(bbox[i, :, 4:6], c[i], s[i], (w, h))
        bbox[i, :, 6:8] = transform_preds(bbox[i, :, 6:8], c[i], s[i], (w, h))
    return bbox


def gbox_post_process(gbox, c, s, h, w):
    # Same as above, for the vertex rows (vertex plus four centers).
    for i in range(gbox.shape[0]):
        gbox[i, :, 0:2] = transform_preds(gbox[i, :, 0:2], c[i], s[i], (w, h))
        gbox[i, :, 2:4] = transform_preds(gbox[i, :, 2:4], c[i], s[i], (w, h))
        gbox[i, :, 4:6] = transform_preds(gbox[i, :, 4:6], c[i], s[i], (w, h))
        gbox[i, :, 6:8] = transform_preds(gbox[i, :, 6:8], c[i], s[i], (w, h))
        gbox[i, :, 8:10] = transform_preds(gbox[i, :, 8:10], c[i], s[i],
                                           (w, h))
    return gbox


def nms(dets, thresh):
    # Suppress overlapping quadrilaterals: a detection is dropped when its
    # center falls inside an earlier, higher-scoring detection. Assumes `dets`
    # is sorted by descending score, as bbox_decode produces. NB: this
    # pipeline passes a batched (1, K, 10) array, so the early return below
    # makes the call a pass-through for batch size 1.
    if len(dets) < 2:
        return dets
    index_keep = []
    keep = []
    for i in range(len(dets)):
        box = dets[i]
        if box[-1] < thresh:
            break
        max_score_index = -1
        ctx = (dets[i][0] + dets[i][2] + dets[i][4] + dets[i][6]) / 4
        cty = (dets[i][1] + dets[i][3] + dets[i][5] + dets[i][7]) / 4
        for j in range(len(dets)):
            if i == j or dets[j][-1] < thresh:
                break
            x1, y1 = dets[j][0], dets[j][1]
            x2, y2 = dets[j][2], dets[j][3]
            x3, y3 = dets[j][4], dets[j][5]
            x4, y4 = dets[j][6], dets[j][7]
            a = (x2 - x1) * (cty - y1) - (y2 - y1) * (ctx - x1)
            b = (x3 - x2) * (cty - y2) - (y3 - y2) * (ctx - x2)
            c = (x4 - x3) * (cty - y3) - (y4 - y3) * (ctx - x3)
            d = (x1 - x4) * (cty - y4) - (y1 - y4) * (ctx - x4)
            if (a > 0 and b > 0 and c > 0 and d > 0) or (a < 0 and b < 0
                                                         and c < 0 and d < 0):
                if dets[i][8] > dets[j][8] and max_score_index < 0:
                    max_score_index = i
                elif dets[i][8] < dets[j][8]:
                    max_score_index = -2
                    break
        if max_score_index > -1:
            index_keep.append(max_score_index)
        elif max_score_index == -1:
            index_keep.append(i)
    for i in range(0, len(index_keep)):
        keep.append(dets[index_keep[i]])
    return np.array(keep)


def group_bbox_by_gbox(bboxes,
                       gboxes,
                       score_thred=0.3,
                       v2c_dist_thred=2,
                       c2v_dist_thred=0.5):
    # Snap cell corners onto shared vertices: for every predicted vertex, find
    # the cells whose interior contains one of its regressed centers and move
    # that cell's nearest corner onto the vertex, so neighbouring cells end up
    # sharing exact corner coordinates.

    def point_in_box(box, point):
        x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
        x3, y3, x4, y4 = box[4], box[5], box[6], box[7]
        ctx, cty = point[0], point[1]
        a = (x2 - x1) * (cty - y1) - (y2 - y1) * (ctx - x1)
        b = (x3 - x2) * (cty - y2) - (y3 - y2) * (ctx - x2)
        c = (x4 - x3) * (cty - y3) - (y4 - y3) * (ctx - x3)
        d = (x1 - x4) * (cty - y4) - (y1 - y4) * (ctx - x4)
        if (a > 0 and b > 0 and c > 0 and d > 0) or (a < 0 and b < 0 and c < 0
                                                     and d < 0):
            return True
        else:
            return False

    def get_distance(pt1, pt2):
        return math.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0])
                         + (pt1[1] - pt2[1]) * (pt1[1] - pt2[1]))

    dets = copy.deepcopy(bboxes)
    sign = np.zeros((len(dets), 4))
    for idx, gbox in enumerate(gboxes):  # row: vertex x, y, 4 centers, score
        if gbox[10] < score_thred:
            break
        vertex = [gbox[0], gbox[1]]
        for i in range(0, 4):
            center = [gbox[2 * i + 2], gbox[2 * i + 3]]
            if get_distance(vertex, center) < v2c_dist_thred:
                continue
            for k, bbox in enumerate(dets):
                if bbox[8] < score_thred:
                    break
                if sum(sign[k]) == 4:
                    continue
                w = (abs(bbox[6] - bbox[0]) + abs(bbox[4] - bbox[2])) / 2
                h = (abs(bbox[3] - bbox[1]) + abs(bbox[5] - bbox[7])) / 2
                m = max(w, h)
                if point_in_box(bbox, center):
                    min_dist, min_id = 1e4, -1
                    for j in range(0, 4):
                        dist = get_distance(vertex,
                                            [bbox[2 * j], bbox[2 * j + 1]])
                        if dist < min_dist:
                            min_dist = dist
                            min_id = j
                    if (min_id > -1 and min_dist < c2v_dist_thred * m
                            and sign[k][min_id] == 0):
                        bboxes[k][2 * min_id] = vertex[0]
                        bboxes[k][2 * min_id + 1] = vertex[1]
                        sign[k][min_id] = 1
    return bboxes
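To make the corner-snapping concrete, here is a toy run (editor's sketch, not part of the diff; the corner ordering is chosen to satisfy the width/height formulas above — the model's actual ordering is fixed by training): one 10x10 cell plus one vertex whose first paired center falls inside it.

```python
import numpy as np

# One cell: corners (0,0), (0,10), (10,10), (10,0), score 0.9, class 0.
bboxes = np.array([[0., 0., 0., 10., 10., 10., 10., 0., 0.9, 0.]])
# One vertex at (0.3, 0.2); its first regressed center (5, 5) lies inside
# the cell, the other three centers are far-away dummies; score 0.8, class 0.
gboxes = np.array(
    [[0.3, 0.2, 5., 5., 99., 99., 99., 99., 99., 99., 0.8, 0.]])
out = group_bbox_by_gbox(bboxes.copy(), gboxes)
print(out[0][:2])  # [0.3 0.2] -> nearest corner snapped onto the vertex
```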
modelscope/pipelines/cv/table_recognition_pipeline.py (new file) — the pipeline itself:

@@ -0,0 +1,119 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.ocr_utils.model_dla34 import TableRecModel
from modelscope.pipelines.cv.ocr_utils.table_process import (
    bbox_decode, bbox_post_process, gbox_decode, gbox_post_process,
    get_affine_transform, group_bbox_by_gbox, nms)
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.table_recognition, module_name=Pipelines.table_recognition)
class TableRecognitionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """A pipeline for table structure recognition (cell polygons).

        Args:
            model: model id on the ModelScope hub, or a local model directory.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        self.K = 1000  # max cell-center candidates to decode
        self.MK = 4000  # max vertex candidates to decode
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.infer_model = TableRecModel().to(self.device)
        self.infer_model.eval()
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'state_dict' in checkpoint:
            self.infer_model.load_state_dict(checkpoint['state_dict'])
        else:
            self.infer_model.load_state_dict(checkpoint)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        mean = np.array([0.408, 0.447, 0.470],
                        dtype=np.float32).reshape(1, 1, 3)
        std = np.array([0.289, 0.274, 0.278],
                       dtype=np.float32).reshape(1, 1, 3)
        height, width = img.shape[0:2]
        inp_height, inp_width = 1024, 1024
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(height, width) * 1.0
        trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
        # The resize below is a no-op (target size equals the source size);
        # kept as in the upstream CenterNet preprocessing.
        resized_image = cv2.resize(img, (width, height))
        inp_image = cv2.warpAffine(
            resized_image,
            trans_input, (inp_width, inp_height),
            flags=cv2.INTER_LINEAR)
        inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)
        images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
                                                      inp_width)
        images = torch.from_numpy(images).to(self.device)
        meta = {
            'c': c,
            's': s,
            'input_height': inp_height,
            'input_width': inp_width,
            'out_height': inp_height // 4,
            'out_width': inp_width // 4
        }
        result = {'img': images, 'meta': meta}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        pred = self.infer_model(input['img'])
        return {'results': pred, 'meta': input['meta']}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        output = inputs['results'][0]
        meta = inputs['meta']
        hm = output['hm'].sigmoid_()
        v2c = output['v2c']
        c2v = output['c2v']
        reg = output['reg']
        # Channel 0 of `hm` scores cell centers, channel 1 scores vertices.
        bbox, _ = bbox_decode(hm[:, 0:1, :, :], c2v, reg=reg, K=self.K)
        gbox, _ = gbox_decode(hm[:, 1:2, :, :], v2c, reg=reg, K=self.MK)
        bbox = bbox.detach().cpu().numpy()
        gbox = gbox.detach().cpu().numpy()
        bbox = nms(bbox, 0.3)
        # meta['c'] was collated to a tensor by the base Pipeline, hence .cpu().
        bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()],
                                 [meta['s']], meta['out_height'],
                                 meta['out_width'])
        gbox = gbox_post_process(gbox.copy(), [meta['c'].cpu().numpy()],
                                 [meta['s']], meta['out_height'],
                                 meta['out_width'])
        bbox = group_bbox_by_gbox(bbox[0], gbox[0])
        res = []
        for box in bbox:
            if box[8] > 0.3:
                res.append(box[0:8])
        result = {OutputKeys.POLYGONS: np.array(res)}
        return result
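End to end, the pipeline returns the kept cell polygons under `OutputKeys.POLYGONS` as an (N, 8) array of `[x1, y1, ..., x4, y4]`. A usage-and-visualization sketch (editor's, not part of the diff; file paths are placeholders):

```python
import cv2
import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

table_recognition = pipeline(Tasks.table_recognition)
image_path = 'path/to/table_image.jpg'  # placeholder
result = table_recognition(image_path)

img = cv2.imread(image_path)
for poly in result[OutputKeys.POLYGONS]:
    pts = np.array(poly, dtype=np.int32).reshape(-1, 1, 2)
    cv2.polylines(img, [pts], True, (0, 255, 0), 2)
cv2.imwrite('table_cells_vis.jpg', img)
```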
modelscope/utils/constant.py — add the task constant:

@@ -16,6 +16,7 @@ class CVTasks(object):
     # ocr
     ocr_detection = 'ocr-detection'
     ocr_recognition = 'ocr-recognition'
+    table_recognition = 'table-recognition'

     # human face body related
     animal_recognition = 'animal-recognition'
tests/pipelines/test_table_recognition.py (new file) — unit test:

@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class TableRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_dla34_table-structure-recognition_cycle-centernet'
        self.test_image = 'data/test/images/table_recognition.jpg'
        self.task = Tasks.table_recognition

    def pipeline_inference(self, pipe: Pipeline, input_location: str):
        result = pipe(input_location)
        print('table recognition results: ')
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        table_recognition = pipeline(
            Tasks.table_recognition, model=self.model_id)
        self.pipeline_inference(table_recognition, self.test_image)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        table_recognition = pipeline(Tasks.table_recognition)
        self.pipeline_inference(table_recognition, self.test_image)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()
CI test configuration — run the new test in the isolated group:

@@ -39,6 +39,7 @@ isolated: # test cases that may require an excessive amount of GPU memory or run
 - test_automatic_speech_recognition.py
 - test_image_matting.py
 - test_skin_retouching.py
+- test_table_recognition.py

envs:
  default: # default env; cases not assigned to another env run here (pytorch).