diff --git a/data/test/images/table_recognition.jpg b/data/test/images/table_recognition.jpg new file mode 100755 index 00000000..9978796f --- /dev/null +++ b/data/test/images/table_recognition.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4b7e23f02a35136707ac7862e0a8468797f239e89497351847cfacb2a9c24f6 +size 202112 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 33b1b3a3..5b56e09a 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -151,6 +151,7 @@ class Pipelines(object): image_denoise = 'nafnet-image-denoise' person_image_cartoon = 'unet-person-image-cartoon' ocr_detection = 'resnet18-ocr-detection' + table_recognition = 'dla34-table-recognition' action_recognition = 'TAdaConv_action-recognition' animal_recognition = 'resnet101-animal-recognition' general_recognition = 'resnet101-general-recognition' diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py index 377eff6f..e3251e48 100644 --- a/modelscope/outputs/outputs.py +++ b/modelscope/outputs/outputs.py @@ -59,6 +59,7 @@ TASK_OUTPUTS = { # [x1, y1, x2, y2, x3, y3, x4, y4] # } Tasks.ocr_detection: [OutputKeys.POLYGONS], + Tasks.table_recognition: [OutputKeys.POLYGONS], # ocr recognition result for single sample # { diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 70f8f11c..8b097bfc 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -82,6 +82,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_unet_person-image-cartoon_compound-models'), Tasks.ocr_detection: (Pipelines.ocr_detection, 'damo/cv_resnet18_ocr-detection-line-level_damo'), + Tasks.table_recognition: + (Pipelines.table_recognition, + 'damo/cv_dla34_table-structure-recognition_cycle-centernet'), Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'), Tasks.feature_extraction: (Pipelines.feature_extraction, 'damo/pert_feature-extraction_base-test'), diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 5e9220bd..e196e8f7 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -41,6 +41,7 @@ if TYPE_CHECKING: from .live_category_pipeline import LiveCategoryPipeline from .ocr_detection_pipeline import OCRDetectionPipeline from .ocr_recognition_pipeline import OCRRecognitionPipeline + from .table_recognition_pipeline import TableRecognitionPipeline from .skin_retouching_pipeline import SkinRetouchingPipeline from .tinynas_classification_pipeline import TinynasClassificationPipeline from .video_category_pipeline import VideoCategoryPipeline @@ -108,6 +109,7 @@ else: 'image_inpainting_pipeline': ['ImageInpaintingPipeline'], 'ocr_detection_pipeline': ['OCRDetectionPipeline'], 'ocr_recognition_pipeline': ['OCRRecognitionPipeline'], + 'table_recognition_pipeline': ['TableRecognitionPipeline'], 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], 'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], 'video_category_pipeline': ['VideoCategoryPipeline'], diff --git a/modelscope/pipelines/cv/ocr_utils/model_dla34.py b/modelscope/pipelines/cv/ocr_utils/model_dla34.py new file mode 100644 index 00000000..05d08abb --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/model_dla34.py @@ -0,0 +1,655 @@ +# ------------------------------------------------------------------------------ +# The implementation is adopted from CenterNet, +# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git +# 
------------------------------------------------------------------------------ + +import math +from os.path import join + +import numpy as np +import torch +from torch import nn + +BatchNorm = nn.BatchNorm2d + + +class BasicBlock(nn.Module): + + def __init__(self, inplanes, planes, stride=1, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation) + self.bn1 = BatchNorm(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=1, + padding=dilation, + bias=False, + dilation=dilation) + self.bn2 = BatchNorm(planes) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 2 + + def __init__(self, inplanes, planes, stride=1, dilation=1): + super(Bottleneck, self).__init__() + expansion = Bottleneck.expansion + bottle_planes = planes // expansion + self.conv1 = nn.Conv2d( + inplanes, bottle_planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm(bottle_planes) + self.conv2 = nn.Conv2d( + bottle_planes, + bottle_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation) + self.bn2 = BatchNorm(bottle_planes) + self.conv3 = nn.Conv2d( + bottle_planes, planes, kernel_size=1, bias=False) + self.bn3 = BatchNorm(planes) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class BottleneckX(nn.Module): + expansion = 2 + cardinality = 32 + + def __init__(self, inplanes, planes, stride=1, dilation=1): + super(BottleneckX, self).__init__() + cardinality = BottleneckX.cardinality + bottle_planes = planes * cardinality // 32 + self.conv1 = nn.Conv2d( + inplanes, bottle_planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm(bottle_planes) + self.conv2 = nn.Conv2d( + bottle_planes, + bottle_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + groups=cardinality) + self.bn2 = BatchNorm(bottle_planes) + self.conv3 = nn.Conv2d( + bottle_planes, planes, kernel_size=1, bias=False) + self.bn3 = BatchNorm(planes) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class Root(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, residual): + super(Root, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 1, + stride=1, + bias=False, + padding=(kernel_size - 1) // 2) + self.bn = BatchNorm(out_channels) + self.relu = nn.ReLU(inplace=True) + self.residual = residual + + def forward(self, *x): + children = x + x = self.conv(torch.cat(x, 1)) + x = self.bn(x) + if self.residual: + x 
+= children[0] + x = self.relu(x) + + return x + + +class Tree(nn.Module): + + def __init__(self, + levels, + block, + in_channels, + out_channels, + stride=1, + level_root=False, + root_dim=0, + root_kernel_size=1, + dilation=1, + root_residual=False): + super(Tree, self).__init__() + if root_dim == 0: + root_dim = 2 * out_channels + if level_root: + root_dim += in_channels + if levels == 1: + self.tree1 = block( + in_channels, out_channels, stride, dilation=dilation) + self.tree2 = block( + out_channels, out_channels, 1, dilation=dilation) + else: + self.tree1 = Tree( + levels - 1, + block, + in_channels, + out_channels, + stride, + root_dim=0, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual) + self.tree2 = Tree( + levels - 1, + block, + out_channels, + out_channels, + root_dim=root_dim + out_channels, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual) + if levels == 1: + self.root = Root(root_dim, out_channels, root_kernel_size, + root_residual) + self.level_root = level_root + self.root_dim = root_dim + self.downsample = None + self.project = None + self.levels = levels + if stride > 1: + self.downsample = nn.MaxPool2d(stride, stride=stride) + if in_channels != out_channels: + self.project = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False), BatchNorm(out_channels)) + + def forward(self, x, residual=None, children=None): + children = [] if children is None else children + bottom = self.downsample(x) if self.downsample else x + residual = self.project(bottom) if self.project else bottom + if self.level_root: + children.append(bottom) + x1 = self.tree1(x, residual) + if self.levels == 1: + x2 = self.tree2(x1) + x = self.root(x2, x1, *children) + else: + children.append(x1) + x = self.tree2(x1, children=children) + return x + + +class DLA(nn.Module): + + def __init__(self, + levels, + channels, + num_classes=1000, + block=BasicBlock, + residual_root=False, + return_levels=False, + pool_size=7, + linear_root=False): + super(DLA, self).__init__() + self.channels = channels + self.return_levels = return_levels + self.num_classes = num_classes + self.base_layer = nn.Sequential( + nn.Conv2d( + 3, channels[0], kernel_size=7, stride=1, padding=3, + bias=False), BatchNorm(channels[0]), nn.ReLU(inplace=True)) + self.level0 = self._make_conv_level(channels[0], channels[0], + levels[0]) + self.level1 = self._make_conv_level( + channels[0], channels[1], levels[1], stride=2) + self.level2 = Tree( + levels[2], + block, + channels[1], + channels[2], + 2, + level_root=False, + root_residual=residual_root) + self.level3 = Tree( + levels[3], + block, + channels[2], + channels[3], + 2, + level_root=True, + root_residual=residual_root) + self.level4 = Tree( + levels[4], + block, + channels[3], + channels[4], + 2, + level_root=True, + root_residual=residual_root) + self.level5 = Tree( + levels[5], + block, + channels[4], + channels[5], + 2, + level_root=True, + root_residual=residual_root) + + self.avgpool = nn.AvgPool2d(pool_size) + self.fc = nn.Conv2d( + channels[-1], + num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, BatchNorm): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_level(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes: + downsample = nn.Sequential( + nn.MaxPool2d(stride, stride=stride), + nn.Conv2d( + inplanes, planes, kernel_size=1, stride=1, bias=False), + BatchNorm(planes), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample=downsample)) + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): + modules = [] + for i in range(convs): + modules.extend([ + nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride if i == 0 else 1, + padding=dilation, + bias=False, + dilation=dilation), + BatchNorm(planes), + nn.ReLU(inplace=True) + ]) + inplanes = planes + return nn.Sequential(*modules) + + def forward(self, x): + y = [] + x = self.base_layer(x) + for i in range(6): + x = getattr(self, 'level{}'.format(i))(x) + y.append(x) + if self.return_levels: + return y + else: + x = self.avgpool(x) + x = self.fc(x) + x = x.view(x.size(0), -1) + + return x + + +def dla34(pretrained, **kwargs): # DLA-34 + model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], + block=BasicBlock, + **kwargs) + return model + + +def dla46_c(pretrained=None, **kwargs): # DLA-46-C + Bottleneck.expansion = 2 + model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], + block=Bottleneck, + **kwargs) + return model + + +def dla46x_c(pretrained=None, **kwargs): # DLA-X-46-C + BottleneckX.expansion = 2 + model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], + block=BottleneckX, + **kwargs) + return model + + +def dla60x_c(pretrained, **kwargs): # DLA-X-60-C + BottleneckX.expansion = 2 + model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256], + block=BottleneckX, + **kwargs) + return model + + +def dla60(pretrained=None, **kwargs): # DLA-60 + Bottleneck.expansion = 2 + model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], + block=Bottleneck, + **kwargs) + return model + + +def dla60x(pretrained=None, **kwargs): # DLA-X-60 + BottleneckX.expansion = 2 + model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], + block=BottleneckX, + **kwargs) + return model + + +def dla102(pretrained=None, **kwargs): # DLA-102 + Bottleneck.expansion = 2 + model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], + block=Bottleneck, + residual_root=True, + **kwargs) + return model + + +def dla102x(pretrained=None, **kwargs): # DLA-X-102 + BottleneckX.expansion = 2 + model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], + block=BottleneckX, + residual_root=True, + **kwargs) + return model + + +def dla102x2(pretrained=None, **kwargs): # DLA-X-102 64 + BottleneckX.cardinality = 64 + model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], + block=BottleneckX, + residual_root=True, + **kwargs) + return model + + +def dla169(pretrained=None, **kwargs): # DLA-169 + Bottleneck.expansion = 2 + model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024], + block=Bottleneck, + residual_root=True, + **kwargs) + return model + + +def set_bn(bn): + global BatchNorm + BatchNorm = bn + dla.BatchNorm = bn + + +class Identity(nn.Module): + + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return x + + +def fill_up_weights(up): + w = up.weight.data + f = math.ceil(w.size(2) / 2) + c = (2 * f - 1 - f % 2) / (2. 
* f) + for i in range(w.size(2)): + for j in range(w.size(3)): + w[0, 0, i, j] = \ + (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) + for c in range(1, w.size(0)): + w[c, 0, :, :] = w[0, 0, :, :] + + +class IDAUp(nn.Module): + + def __init__(self, node_kernel, out_dim, channels, up_factors): + super(IDAUp, self).__init__() + self.channels = channels + self.out_dim = out_dim + for i, c in enumerate(channels): + if c == out_dim: + proj = Identity() + else: + proj = nn.Sequential( + nn.Conv2d(c, out_dim, kernel_size=1, stride=1, bias=False), + BatchNorm(out_dim), nn.ReLU(inplace=True)) + f = int(up_factors[i]) + if f == 1: + up = Identity() + else: + up = nn.ConvTranspose2d( + out_dim, + out_dim, + f * 2, + stride=f, + padding=f // 2, + output_padding=0, + groups=out_dim, + bias=False) + fill_up_weights(up) + setattr(self, 'proj_' + str(i), proj) + setattr(self, 'up_' + str(i), up) + + for i in range(1, len(channels)): + node = nn.Sequential( + nn.Conv2d( + out_dim * 2, + out_dim, + kernel_size=node_kernel, + stride=1, + padding=node_kernel // 2, + bias=False), BatchNorm(out_dim), nn.ReLU(inplace=True)) + setattr(self, 'node_' + str(i), node) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, BatchNorm): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, layers): + assert len(self.channels) == len(layers), \ + '{} vs {} layers'.format(len(self.channels), len(layers)) + layers = list(layers) + for i, l in enumerate(layers): + upsample = getattr(self, 'up_' + str(i)) + project = getattr(self, 'proj_' + str(i)) + layers[i] = upsample(project(l)) + x = layers[0] + y = [] + for i in range(1, len(layers)): + node = getattr(self, 'node_' + str(i)) + x = node(torch.cat([x, layers[i]], 1)) + y.append(x) + return x, y + + +class DLAUp(nn.Module): + + def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None): + super(DLAUp, self).__init__() + if in_channels is None: + in_channels = channels + self.channels = channels + channels = list(channels) + scales = np.array(scales, dtype=int) + for i in range(len(channels) - 1): + j = -i - 2 + setattr( + self, 'ida_{}'.format(i), + IDAUp(3, channels[j], in_channels[j:], + scales[j:] // scales[j])) + scales[j + 1:] = scales[j] + in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]] + + def forward(self, layers): + layers = list(layers) + assert len(layers) > 1 + for i in range(len(layers) - 1): + ida = getattr(self, 'ida_{}'.format(i)) + x, y = ida(layers[-i - 2:]) + layers[-i - 1:] = y + return x + + +def fill_fc_weights(layers): + for m in layers.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.001) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class DLASeg(nn.Module): + + def __init__(self, + base_name='dla34', + pretrained=False, + down_ratio=4, + head_conv=256): + super(DLASeg, self).__init__() + assert down_ratio in [2, 4, 8, 16] + self.heads = {'hm': 2, 'v2c': 8, 'c2v': 8, 'reg': 2} + self.first_level = int(np.log2(down_ratio)) + self.base = globals()[base_name]( + pretrained=pretrained, return_levels=True) + channels = self.base.channels + scales = [2**i for i in range(len(channels[self.first_level:]))] + self.dla_up = DLAUp(channels[self.first_level:], scales=scales) + + for head in self.heads: + classes = self.heads[head] + if head_conv > 0: + fc = nn.Sequential( + nn.Conv2d( + channels[self.first_level], + head_conv, + 
kernel_size=3, + padding=1, + bias=True), nn.ReLU(inplace=True), + nn.Conv2d( + head_conv, + classes, + kernel_size=1, + stride=1, + padding=0, + bias=True)) + if 'hm' in head: + fc[-1].bias.data.fill_(-2.19) + else: + fill_fc_weights(fc) + else: + fc = nn.Conv2d( + channels[self.first_level], + classes, + kernel_size=1, + stride=1, + padding=0, + bias=True) + if 'hm' in head: + fc.bias.data.fill_(-2.19) + else: + fill_fc_weights(fc) + self.__setattr__(head, fc) + + def forward(self, x): + x = self.base(x) + x = self.dla_up(x[self.first_level:]) + ret = {} + for head in self.heads: + ret[head] = self.__getattr__(head)(x) + return [ret] + + +def TableRecModel(): + model = DLASeg() + return model diff --git a/modelscope/pipelines/cv/ocr_utils/table_process.py b/modelscope/pipelines/cv/ocr_utils/table_process.py new file mode 100644 index 00000000..864ec71d --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/table_process.py @@ -0,0 +1,315 @@ +# ------------------------------------------------------------------------------ +# The implementation is adopted from CenterNet, +# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git +# ------------------------------------------------------------------------------ + +import copy +import math +import random + +import cv2 +import numpy as np +import torch +import torch.nn as nn + + +def transform_preds(coords, center, scale, output_size, rot=0): + target_coords = np.zeros(coords.shape) + trans = get_affine_transform(center, scale, rot, output_size, inv=1) + for p in range(coords.shape[0]): + target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) + return target_coords + + +def get_affine_transform(center, + scale, + rot, + output_size, + shift=np.array([0, 0], dtype=np.float32), + inv=0): + if not isinstance(scale, np.ndarray) and not isinstance(scale, list): + scale = np.array([scale, scale], dtype=np.float32) + + scale_tmp = scale + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = get_dir([0, src_w * -0.5], rot_rad) + dst_dir = np.array([0, dst_w * -0.5], np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + dst = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * shift + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir + + src[2:, :] = get_3rd_point(src[0, :], src[1, :]) + dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def affine_transform(pt, t): + new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32).T + new_pt = np.dot(t, new_pt) + return new_pt[:2] + + +def get_dir(src_point, rot_rad): + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + + src_result = [0, 0] + src_result[0] = src_point[0] * cs - src_point[1] * sn + src_result[1] = src_point[0] * sn + src_point[1] * cs + + return src_result + + +def get_3rd_point(a, b): + direct = a - b + return b + np.array([-direct[1], direct[0]], dtype=np.float32) + + +def _sigmoid(x): + y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4) + return y + + +def _gather_feat(feat, ind, mask=None): + dim = feat.size(2) + ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) + feat = feat.gather(1, ind) + if mask is not None: + mask = mask.unsqueeze(2).expand_as(feat) + feat = 
feat[mask] + feat = feat.view(-1, dim) + return feat + + +def _tranpose_and_gather_feat(feat, ind): + feat = feat.permute(0, 2, 3, 1).contiguous() + feat = feat.view(feat.size(0), -1, feat.size(3)) + feat = _gather_feat(feat, ind) + return feat + + +def _nms(heat, kernel=3): + pad = (kernel - 1) // 2 + + hmax = nn.functional.max_pool2d( + heat, (kernel, kernel), stride=1, padding=pad) + keep = (hmax == heat).float() + return heat * keep, keep + + +def _topk(scores, K=40): + batch, cat, height, width = scores.size() + + topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K) + + topk_inds = topk_inds % (height * width) + topk_ys = (topk_inds / width).int().float() + topk_xs = (topk_inds % width).int().float() + + topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K) + topk_clses = (topk_ind / K).int() + topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), + topk_ind).view(batch, K) + topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K) + topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K) + + return topk_score, topk_inds, topk_clses, topk_ys, topk_xs + + +def bbox_decode(heat, wh, reg=None, K=100): + batch, cat, height, width = heat.size() + + heat, keep = _nms(heat) + + scores, inds, clses, ys, xs = _topk(heat, K=K) + if reg is not None: + reg = _tranpose_and_gather_feat(reg, inds) + reg = reg.view(batch, K, 2) + xs = xs.view(batch, K, 1) + reg[:, :, 0:1] + ys = ys.view(batch, K, 1) + reg[:, :, 1:2] + else: + xs = xs.view(batch, K, 1) + 0.5 + ys = ys.view(batch, K, 1) + 0.5 + wh = _tranpose_and_gather_feat(wh, inds) + wh = wh.view(batch, K, 8) + clses = clses.view(batch, K, 1).float() + scores = scores.view(batch, K, 1) + + bboxes = torch.cat( + [ + xs - wh[..., 0:1], + ys - wh[..., 1:2], + xs - wh[..., 2:3], + ys - wh[..., 3:4], + xs - wh[..., 4:5], + ys - wh[..., 5:6], + xs - wh[..., 6:7], + ys - wh[..., 7:8], + ], + dim=2, + ) + detections = torch.cat([bboxes, scores, clses], dim=2) + + return detections, keep + + +def gbox_decode(mk, st_reg, reg=None, K=400): + batch, cat, height, width = mk.size() + mk, keep = _nms(mk) + scores, inds, clses, ys, xs = _topk(mk, K=K) + if reg is not None: + reg = _tranpose_and_gather_feat(reg, inds) + reg = reg.view(batch, K, 2) + xs = xs.view(batch, K, 1) + reg[:, :, 0:1] + ys = ys.view(batch, K, 1) + reg[:, :, 1:2] + else: + xs = xs.view(batch, K, 1) + 0.5 + ys = ys.view(batch, K, 1) + 0.5 + scores = scores.view(batch, K, 1) + clses = clses.view(batch, K, 1).float() + st_Reg = _tranpose_and_gather_feat(st_reg, inds) + bboxes = torch.cat( + [ + xs - st_Reg[..., 0:1], + ys - st_Reg[..., 1:2], + xs - st_Reg[..., 2:3], + ys - st_Reg[..., 3:4], + xs - st_Reg[..., 4:5], + ys - st_Reg[..., 5:6], + xs - st_Reg[..., 6:7], + ys - st_Reg[..., 7:8], + ], + dim=2, + ) + return torch.cat([xs, ys, bboxes, scores, clses], dim=2), keep + + +def bbox_post_process(bbox, c, s, h, w): + for i in range(bbox.shape[0]): + bbox[i, :, 0:2] = transform_preds(bbox[i, :, 0:2], c[i], s[i], (w, h)) + bbox[i, :, 2:4] = transform_preds(bbox[i, :, 2:4], c[i], s[i], (w, h)) + bbox[i, :, 4:6] = transform_preds(bbox[i, :, 4:6], c[i], s[i], (w, h)) + bbox[i, :, 6:8] = transform_preds(bbox[i, :, 6:8], c[i], s[i], (w, h)) + return bbox + + +def gbox_post_process(gbox, c, s, h, w): + for i in range(gbox.shape[0]): + gbox[i, :, 0:2] = transform_preds(gbox[i, :, 0:2], c[i], s[i], (w, h)) + gbox[i, :, 2:4] = transform_preds(gbox[i, :, 2:4], c[i], s[i], (w, h)) + gbox[i, :, 4:6] = transform_preds(gbox[i, :, 4:6], c[i], 
s[i], (w, h)) + gbox[i, :, 6:8] = transform_preds(gbox[i, :, 6:8], c[i], s[i], (w, h)) + gbox[i, :, 8:10] = transform_preds(gbox[i, :, 8:10], c[i], s[i], + (w, h)) + return gbox + + +def nms(dets, thresh): + if len(dets) < 2: + return dets + index_keep = [] + keep = [] + for i in range(len(dets)): + box = dets[i] + if box[-1] < thresh: + break + max_score_index = -1 + ctx = (dets[i][0] + dets[i][2] + dets[i][4] + dets[i][6]) / 4 + cty = (dets[i][1] + dets[i][3] + dets[i][5] + dets[i][7]) / 4 + for j in range(len(dets)): + if i == j or dets[j][-1] < thresh: + break + x1, y1 = dets[j][0], dets[j][1] + x2, y2 = dets[j][2], dets[j][3] + x3, y3 = dets[j][4], dets[j][5] + x4, y4 = dets[j][6], dets[j][7] + a = (x2 - x1) * (cty - y1) - (y2 - y1) * (ctx - x1) + b = (x3 - x2) * (cty - y2) - (y3 - y2) * (ctx - x2) + c = (x4 - x3) * (cty - y3) - (y4 - y3) * (ctx - x3) + d = (x1 - x4) * (cty - y4) - (y1 - y4) * (ctx - x4) + if (a > 0 and b > 0 and c > 0 and d > 0) or (a < 0 and b < 0 + and c < 0 and d < 0): + if dets[i][8] > dets[j][8] and max_score_index < 0: + max_score_index = i + elif dets[i][8] < dets[j][8]: + max_score_index = -2 + break + if max_score_index > -1: + index_keep.append(max_score_index) + elif max_score_index == -1: + index_keep.append(i) + for i in range(0, len(index_keep)): + keep.append(dets[index_keep[i]]) + return np.array(keep) + + +def group_bbox_by_gbox(bboxes, + gboxes, + score_thred=0.3, + v2c_dist_thred=2, + c2v_dist_thred=0.5): + + def point_in_box(box, point): + x1, y1, x2, y2 = box[0], box[1], box[2], box[3] + x3, y3, x4, y4 = box[4], box[5], box[6], box[7] + ctx, cty = point[0], point[1] + a = (x2 - x1) * (cty - y1) - (y2 - y1) * (ctx - x1) + b = (x3 - x2) * (cty - y2) - (y3 - y2) * (ctx - x2) + c = (x4 - x3) * (cty - y3) - (y4 - y3) * (ctx - x3) + d = (x1 - x4) * (cty - y4) - (y1 - y4) * (ctx - x4) + if (a > 0 and b > 0 and c > 0 and d > 0) or (a < 0 and b < 0 and c < 0 + and d < 0): + return True + else: + return False + + def get_distance(pt1, pt2): + return math.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) + + (pt1[1] - pt2[1]) * (pt1[1] - pt2[1])) + + dets = copy.deepcopy(bboxes) + sign = np.zeros((len(dets), 4)) + + for idx, gbox in enumerate(gboxes): # vertex x,y, gbox, score + if gbox[10] < score_thred: + break + vertex = [gbox[0], gbox[1]] + for i in range(0, 4): + center = [gbox[2 * i + 2], gbox[2 * i + 3]] + if get_distance(vertex, center) < v2c_dist_thred: + continue + for k, bbox in enumerate(dets): + if bbox[8] < score_thred: + break + if sum(sign[k]) == 4: + continue + w = (abs(bbox[6] - bbox[0]) + abs(bbox[4] - bbox[2])) / 2 + h = (abs(bbox[3] - bbox[1]) + abs(bbox[5] - bbox[7])) / 2 + m = max(w, h) + if point_in_box(bbox, center): + min_dist, min_id = 1e4, -1 + for j in range(0, 4): + dist = get_distance(vertex, + [bbox[2 * j], bbox[2 * j + 1]]) + if dist < min_dist: + min_dist = dist + min_id = j + if (min_id > -1 and min_dist < c2v_dist_thred * m + and sign[k][min_id] == 0): + bboxes[k][2 * min_id] = vertex[0] + bboxes[k][2 * min_id + 1] = vertex[1] + sign[k][min_id] = 1 + return bboxes diff --git a/modelscope/pipelines/cv/table_recognition_pipeline.py b/modelscope/pipelines/cv/table_recognition_pipeline.py new file mode 100644 index 00000000..1ee9a4f0 --- /dev/null +++ b/modelscope/pipelines/cv/table_recognition_pipeline.py @@ -0,0 +1,119 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import math +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.pipelines.cv.ocr_utils.model_dla34 import TableRecModel +from modelscope.pipelines.cv.ocr_utils.table_process import ( + bbox_decode, bbox_post_process, gbox_decode, gbox_post_process, + get_affine_transform, group_bbox_by_gbox, nms) +from modelscope.preprocessors import load_image +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.table_recognition, module_name=Pipelines.table_recognition) +class TableRecognitionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + + self.K = 1000 + self.MK = 4000 + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.infer_model = TableRecModel().to(self.device) + self.infer_model.eval() + checkpoint = torch.load(model_path, map_location=self.device) + if 'state_dict' in checkpoint: + self.infer_model.load_state_dict(checkpoint['state_dict']) + else: + self.infer_model.load_state_dict(checkpoint) + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + + mean = np.array([0.408, 0.447, 0.470], + dtype=np.float32).reshape(1, 1, 3) + std = np.array([0.289, 0.274, 0.278], + dtype=np.float32).reshape(1, 1, 3) + height, width = img.shape[0:2] + inp_height, inp_width = 1024, 1024 + c = np.array([width / 2., height / 2.], dtype=np.float32) + s = max(height, width) * 1.0 + + trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height]) + resized_image = cv2.resize(img, (width, height)) + inp_image = cv2.warpAffine( + resized_image, + trans_input, (inp_width, inp_height), + flags=cv2.INTER_LINEAR) + inp_image = ((inp_image / 255. 
- mean) / std).astype(np.float32) + + images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height, + inp_width) + images = torch.from_numpy(images).to(self.device) + meta = { + 'c': c, + 's': s, + 'input_height': inp_height, + 'input_width': inp_width, + 'out_height': inp_height // 4, + 'out_width': inp_width // 4 + } + + result = {'img': images, 'meta': meta} + + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + pred = self.infer_model(input['img']) + return {'results': pred, 'meta': input['meta']} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + output = inputs['results'][0] + meta = inputs['meta'] + hm = output['hm'].sigmoid_() + v2c = output['v2c'] + c2v = output['c2v'] + reg = output['reg'] + bbox, _ = bbox_decode(hm[:, 0:1, :, :], c2v, reg=reg, K=self.K) + gbox, _ = gbox_decode(hm[:, 1:2, :, :], v2c, reg=reg, K=self.MK) + + bbox = bbox.detach().cpu().numpy() + gbox = gbox.detach().cpu().numpy() + bbox = nms(bbox, 0.3) + bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()], + [meta['s']], meta['out_height'], + meta['out_width']) + gbox = gbox_post_process(gbox.copy(), [meta['c'].cpu().numpy()], + [meta['s']], meta['out_height'], + meta['out_width']) + bbox = group_bbox_by_gbox(bbox[0], gbox[0]) + + res = [] + for box in bbox: + if box[8] > 0.3: + res.append(box[0:8]) + + result = {OutputKeys.POLYGONS: np.array(res)} + return result diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index b1bccc4c..5072ebe1 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -16,6 +16,7 @@ class CVTasks(object): # ocr ocr_detection = 'ocr-detection' ocr_recognition = 'ocr-recognition' + table_recognition = 'table-recognition' # human face body related animal_recognition = 'animal-recognition' diff --git a/tests/pipelines/test_table_recognition.py b/tests/pipelines/test_table_recognition.py new file mode 100644 index 00000000..3c6ee74a --- /dev/null +++ b/tests/pipelines/test_table_recognition.py @@ -0,0 +1,41 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import unittest + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TableRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_dla34_table-structure-recognition_cycle-centernet' + self.test_image = 'data/test/images/table_recognition.jpg' + self.task = Tasks.table_recognition + + def pipeline_inference(self, pipe: Pipeline, input_location: str): + result = pipe(input_location) + print('table recognition results: ') + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + table_recognition = pipeline( + Tasks.table_recognition, model=self.model_id) + self.pipeline_inference(table_recognition, self.test_image) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + table_recognition = pipeline(Tasks.table_recognition) + self.pipeline_inference(table_recognition, self.test_image) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/run_config.yaml b/tests/run_config.yaml index e0529f19..2e06b88e 100644 --- a/tests/run_config.yaml +++ b/tests/run_config.yaml @@ -39,6 +39,7 @@ isolated: # test cases that may require excessive anmount of GPU memory or run - test_automatic_speech_recognition.py - test_image_matting.py - test_skin_retouching.py + - test_table_recognition.py envs: default: # default env, case not in other env will in default, pytorch.
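For reference, a minimal end-to-end sketch of the new task, mirroring what tests/pipelines/test_table_recognition.py exercises; the drawing step and the output file name 'table_cells_vis.jpg' are illustrative assumptions rather than part of the ModelScope API.

import cv2
import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline with the model registered as the task default in builder.py.
table_recognition = pipeline(
    Tasks.table_recognition,
    model='damo/cv_dla34_table-structure-recognition_cycle-centernet')
result = table_recognition('data/test/images/table_recognition.jpg')

# OutputKeys.POLYGONS is an (N, 8) array, one quadrilateral per detected table cell:
# [x1, y1, x2, y2, x3, y3, x4, y4] in original-image coordinates.
img = cv2.imread('data/test/images/table_recognition.jpg')
for poly in result[OutputKeys.POLYGONS]:
    pts = np.array(poly, dtype=np.int32).reshape(-1, 1, 2)
    cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
cv2.imwrite('table_cells_vis.jpg', img)  # illustrative output path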
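DLASeg wires a DLA-34 backbone (return_levels=True) into DLAUp/IDAUp upsampling and attaches four heads: 'hm' (two heatmap channels, used in postprocess as cell centers and shared-corner points), 'c2v' and 'v2c' (8-channel offsets from a center to its four vertices, and from a corner point to its four adjacent centers), and 'reg' (2-channel sub-pixel offsets). With down_ratio=4, a 1024x1024 input yields 256x256 maps. A quick shape check, assuming a CPU forward pass on a zero tensor is acceptable:

import torch

from modelscope.pipelines.cv.ocr_utils.model_dla34 import TableRecModel

model = TableRecModel().eval()
with torch.no_grad():
    heads = model(torch.zeros(1, 3, 1024, 1024))[0]  # forward() returns [dict of head outputs]
for name, tensor in heads.items():
    print(name, tuple(tensor.shape))
# Expected: hm (1, 2, 256, 256), v2c (1, 8, 256, 256),
#           c2v (1, 8, 256, 256), reg (1, 2, 256, 256)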
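Decoding follows the CenterNet recipe: _nms keeps only local maxima of a heatmap (a 3x3 max-pool equality test), and _topk converts the surviving peaks into flat indices plus (x, y) grid positions, which bbox_decode/gbox_decode then shift by the regressed offsets. A toy run of those two helpers (the 4x4 heatmap and its peak values are made up for illustration):

import torch

from modelscope.pipelines.cv.ocr_utils.table_process import _nms, _topk

# A 1x1x4x4 heatmap with two peaks; every other cell is zero.
heat = torch.zeros(1, 1, 4, 4)
heat[0, 0, 1, 2] = 0.9
heat[0, 0, 3, 0] = 0.6

heat, keep = _nms(heat)                     # zero out non-maxima in each 3x3 window
scores, inds, clses, ys, xs = _topk(heat, K=2)

print(scores)  # tensor([[0.9000, 0.6000]])
print(ys, xs)  # row tensor([[1., 3.]]) and column tensor([[2., 0.]]) of each peak
print(inds)    # flat indices tensor([[ 6, 12]]) into the 4x4 grid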
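Pre- and post-processing share one affine transform: preprocess() warps the image into a 1024x1024 square crop around its center, and bbox_post_process/gbox_post_process use transform_preds (the inverse transform) to bring predictions from the 4x-downsampled 256x256 output map back to image coordinates. A standalone sketch of that inverse mapping; the 800x600 image size and the sample point are arbitrary values chosen for illustration:

import numpy as np

from modelscope.pipelines.cv.ocr_utils.table_process import transform_preds

height, width = 600, 800                    # assumed original image size
c = np.array([width / 2., height / 2.], dtype=np.float32)  # crop center, as in preprocess()
s = max(height, width) * 1.0                # side length of the square crop
out_w, out_h = 1024 // 4, 1024 // 4         # output map size used in postprocess()

pt_on_map = np.array([[128., 64.]])         # an arbitrary point on the 256x256 map
pt_on_img = transform_preds(pt_on_map, c, s, (out_w, out_h))
print(pt_on_img)                            # -> [[400. 100.]] in the 800x600 image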