diff --git a/data/test/images/ocr_detection.jpg b/data/test/images/ocr_detection.jpg
new file mode 100644
index 00000000..c347810e
--- /dev/null
+++ b/data/test/images/ocr_detection.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c8435db5583400be5d11a2c17910c96133b462c8a99ccaf0e19f4aac34e0a94
+size 141149
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 8897cf31..5e1fbd87 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -28,6 +28,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.image_generation: ('person-image-cartoon',
                              'damo/cv_unet_person-image-cartoon_compound-models'),
+    Tasks.ocr_detection: ('ocr-detection',
+                          'damo/cv_resnet18_ocr-detection-line-level_damo'),
 }
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index 79c85c19..767c90d7 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -1,2 +1,3 @@
 from .image_cartoon_pipeline import ImageCartoonPipeline
 from .image_matting_pipeline import ImageMattingPipeline
+from .ocr_detection_pipeline import OCRDetectionPipeline
diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py
new file mode 100644
index 00000000..9728e441
--- /dev/null
+++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py
@@ -0,0 +1,167 @@
+import math
+import os
+import os.path as osp
+import sys
+from typing import Any, Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+import PIL
+import tensorflow as tf
+import tf_slim as slim
+
+from modelscope.pipelines.base import Input
+from modelscope.preprocessors import load_image
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+from ..base import Pipeline
+from ..builder import PIPELINES
+from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils
+
+if tf.__version__ >= '2.0':
+    tf = tf.compat.v1
+tf.compat.v1.disable_eager_execution()
+
+logger = get_logger()
+
+# constants
+RBOX_DIM = 5
+OFFSET_DIM = 6
+WORD_POLYGON_DIM = 8
+OFFSET_VARIANCE = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
+
+FLAGS = tf.app.flags.FLAGS
+tf.app.flags.DEFINE_float('node_threshold', 0.4,
+                          'Confidence threshold for nodes')
+tf.app.flags.DEFINE_float('link_threshold', 0.6,
+                          'Confidence threshold for links')
+
+
+@PIPELINES.register_module(
+    Tasks.ocr_detection, module_name=Tasks.ocr_detection)
+class OCRDetectionPipeline(Pipeline):
+
+    def __init__(self, model: str):
+        super().__init__(model=model)
+        model_path = osp.join(
+            osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
+            'checkpoint-80000')
+
+        config = tf.ConfigProto(allow_soft_placement=True)
+        config.gpu_options.allow_growth = True
+        self._session = tf.Session(config=config)
+        global_step = tf.get_variable(
+            'global_step', [],
+            initializer=tf.constant_initializer(0),
+            dtype=tf.int64,
+            trainable=False)
+        variable_averages = tf.train.ExponentialMovingAverage(
+            0.997, global_step)
+        self.input_images = tf.placeholder(
+            tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
+        self.output = {}
+
+        # detector
+        detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector()
+        all_maps = detector.build_model(self.input_images, is_training=False)
+
+        # decode local predictions
+        all_nodes, all_links, all_reg = [], [], []
+        for i, maps in enumerate(all_maps):
+            cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
+            reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
+
+            cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
+
+            lnk_prob_pos = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, :2])
+            lnk_prob_mut = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, 2:])
+            lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)
+
+            all_nodes.append(cls_prob)
+            all_links.append(lnk_prob)
+            all_reg.append(reg_maps)
+
+        # decode segments and links
+        image_size = tf.shape(self.input_images)[1:3]
+        segments, group_indices, segment_counts, _ = ops.decode_segments_links_python(
+            image_size,
+            all_nodes,
+            all_links,
+            all_reg,
+            anchor_sizes=list(detector.anchor_sizes))
+
+        # combine segments
+        combined_rboxes, combined_counts = ops.combine_segments_python(
+            segments, group_indices, segment_counts)
+        self.output['combined_rboxes'] = combined_rboxes
+        self.output['combined_counts'] = combined_counts
+
+        with self._session.as_default() as sess:
+            logger.info(f'loading model from {model_path}')
+            # load model
+            model_loader = tf.train.Saver(
+                variable_averages.variables_to_restore())
+            model_loader.restore(sess, model_path)
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        if isinstance(input, str):
+            img = np.array(load_image(input))
+        elif isinstance(input, PIL.Image.Image):
+            img = np.array(input.convert('RGB'))
+        elif isinstance(input, np.ndarray):
+            if len(input.shape) == 2:
+                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
+            img = input[:, :, ::-1]  # in rgb order
+        else:
+            raise TypeError(f'input should be either str, PIL.Image,'
+                            f' np.ndarray, but got {type(input)}')
+        h, w, c = img.shape
+        img_pad = np.zeros((max(h, w), max(h, w), 3), dtype=np.float32)
+        img_pad[:h, :w, :] = img
+
+        resize_size = 1024
+        img_pad_resize = cv2.resize(img_pad, (resize_size, resize_size))
+        img_pad_resize = cv2.cvtColor(img_pad_resize, cv2.COLOR_RGB2BGR)
+        img_pad_resize = img_pad_resize - np.array([123.68, 116.78, 103.94],
+                                                   dtype=np.float32)
+
+        resize_size = tf.stack([resize_size, resize_size])
+        orig_size = tf.stack([max(h, w), max(h, w)])
+        self.output['orig_size'] = orig_size
+        self.output['resize_size'] = resize_size
+
+        result = {'img': np.expand_dims(img_pad_resize, axis=0)}
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        with self._session.as_default():
+            feed_dict = {self.input_images: input['img']}
+            sess_outputs = self._session.run(self.output, feed_dict=feed_dict)
+            return sess_outputs
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        rboxes = inputs['combined_rboxes'][0]
+        count = inputs['combined_counts'][0]
+        rboxes = rboxes[:count, :]
+
+        # convert rboxes to polygons and find their coordinates on the original image
+        orig_h, orig_w = inputs['orig_size']
+        resize_h, resize_w = inputs['resize_size']
+        polygons = utils.rboxes_to_polygons(rboxes)
+        scale_y = float(orig_h) / float(resize_h)
+        scale_x = float(orig_w) / float(resize_w)
+
+        # confine polygons inside image
+        polygons[:, ::2] = np.maximum(
+            0, np.minimum(polygons[:, ::2] * scale_x, orig_w - 1))
+        polygons[:, 1::2] = np.maximum(
+            0, np.minimum(polygons[:, 1::2] * scale_y, orig_h - 1))
+        polygons = np.round(polygons).astype(np.int32)
+
+        # nms
+        dt_n9 = [o + [utils.cal_width(o)] for o in polygons.tolist()]
+        dt_nms = utils.nms_python(dt_n9)
+        dt_polygons = np.array([o[:8] for o in dt_nms])
+
+        result = {'det_polygons': dt_polygons}
+        return result
diff --git a/modelscope/pipelines/cv/ocr_utils/__init__.py b/modelscope/pipelines/cv/ocr_utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git 
a/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py b/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py new file mode 100644 index 00000000..50b8ba02 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py @@ -0,0 +1,158 @@ +import tensorflow as tf +import tf_slim as slim + +from . import ops, resnet18_v1, resnet_utils + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + +# constants +OFFSET_DIM = 6 + +N_LOCAL_LINKS = 8 +N_CROSS_LINKS = 4 +N_SEG_CLASSES = 2 +N_LNK_CLASSES = 4 + +POS_LABEL = 1 +NEG_LABEL = 0 + + +class SegLinkDetector(): + + def __init__(self): + self.anchor_sizes = [6., 11.84210526, 23.68421053, 45., 90., 150.] + + def _detection_classifier(self, + maps, + ksize, + weight_decay, + cross_links=False, + scope=None): + + with tf.variable_scope(scope): + seg_depth = N_SEG_CLASSES + if cross_links: + lnk_depth = N_LNK_CLASSES * (N_LOCAL_LINKS + N_CROSS_LINKS) + else: + lnk_depth = N_LNK_CLASSES * N_LOCAL_LINKS + reg_depth = OFFSET_DIM + map_depth = maps.get_shape()[3] + inter_maps, inter_relu = ops.conv2d( + maps, map_depth, 256, 1, 1, 'SAME', scope='conv_inter') + + dir_maps, dir_relu = ops.conv2d( + inter_relu, 256, 2, ksize, 1, 'SAME', scope='conv_dir') + cen_maps, cen_relu = ops.conv2d( + inter_relu, 256, 2, ksize, 1, 'SAME', scope='conv_cen') + pol_maps, pol_relu = ops.conv2d( + inter_relu, 256, 8, ksize, 1, 'SAME', scope='conv_pol') + concat_relu = tf.concat([dir_relu, cen_relu, pol_relu], axis=-1) + _, lnk_embedding = ops.conv_relu( + concat_relu, 12, 256, 1, 1, scope='lnk_embedding') + lnk_maps, lnk_relu = ops.conv2d( + inter_relu + lnk_embedding, + 256, + lnk_depth, + ksize, + 1, + 'SAME', + scope='conv_lnk') + + char_seg_maps, char_seg_relu = ops.conv2d( + inter_relu, + 256, + seg_depth, + ksize, + 1, + 'SAME', + scope='conv_char_cls') + char_reg_maps, char_reg_relu = ops.conv2d( + inter_relu, + 256, + reg_depth, + ksize, + 1, + 'SAME', + scope='conv_char_reg') + concat_char_relu = tf.concat([char_seg_relu, char_reg_relu], + axis=-1) + _, char_embedding = ops.conv_relu( + concat_char_relu, 8, 256, 1, 1, scope='conv_char_embedding') + seg_maps, seg_relu = ops.conv2d( + inter_relu + char_embedding, + 256, + seg_depth, + ksize, + 1, + 'SAME', + scope='conv_cls') + reg_maps, reg_relu = ops.conv2d( + inter_relu + char_embedding, + 256, + reg_depth, + ksize, + 1, + 'SAME', + scope='conv_reg') + + return seg_relu, lnk_relu, reg_relu + + def _build_cnn(self, images, weight_decay, is_training): + with slim.arg_scope( + resnet18_v1.resnet_arg_scope(weight_decay=weight_decay)): + logits, end_points = resnet18_v1.resnet_v1_18( + images, is_training=is_training, scope='resnet_v1_18') + + outputs = { + 'conv3_3': end_points['pool1'], + 'conv4_3': end_points['pool2'], + 'fc7': end_points['pool3'], + 'conv8_2': end_points['pool4'], + 'conv9_2': end_points['pool5'], + 'conv10_2': end_points['pool6'], + } + return outputs + + def build_model(self, images, is_training=True, scope=None): + + weight_decay = 5e-4 # FLAGS.weight_decay + cnn_outputs = self._build_cnn(images, weight_decay, is_training) + det_0 = self._detection_classifier( + cnn_outputs['conv3_3'], + 3, + weight_decay, + cross_links=False, + scope='dete_0') + det_1 = self._detection_classifier( + cnn_outputs['conv4_3'], + 3, + weight_decay, + cross_links=True, + scope='dete_1') + det_2 = self._detection_classifier( + cnn_outputs['fc7'], + 3, + weight_decay, + cross_links=True, + scope='dete_2') + det_3 = self._detection_classifier( + 
cnn_outputs['conv8_2'], + 3, + weight_decay, + cross_links=True, + scope='dete_3') + det_4 = self._detection_classifier( + cnn_outputs['conv9_2'], + 3, + weight_decay, + cross_links=True, + scope='dete_4') + det_5 = self._detection_classifier( + cnn_outputs['conv10_2'], + 3, + weight_decay, + cross_links=True, + scope='dete_5') + outputs = [det_0, det_1, det_2, det_3, det_4, det_5] + return outputs diff --git a/modelscope/pipelines/cv/ocr_utils/ops.py b/modelscope/pipelines/cv/ocr_utils/ops.py new file mode 100644 index 00000000..2bc8a8bf --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/ops.py @@ -0,0 +1,1098 @@ +import math +import os +import shutil +import uuid + +import cv2 +import numpy as np +import tensorflow as tf + +from . import utils + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + +FLAGS = tf.app.flags.FLAGS +tf.app.flags.DEFINE_string('weight_init_method', 'xavier', + 'Weight initialization method') + +# constants +OFFSET_DIM = 6 +RBOX_DIM = 5 + +N_LOCAL_LINKS = 8 +N_CROSS_LINKS = 4 +N_SEG_CLASSES = 2 +N_LNK_CLASSES = 4 + +MATCH_STATUS_POS = 1 +MATCH_STATUS_NEG = -1 +MATCH_STATUS_IGNORE = 0 +MUT_LABEL = 3 +POS_LABEL = 1 +NEG_LABEL = 0 + +N_DET_LAYERS = 6 + + +def load_oplib(lib_name): + """ + Load TensorFlow operator library. + """ + # use absolute path so that ops.py can be called from other directory + lib_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + 'lib{0}.so'.format(lib_name)) + # duplicate library with a random new name so that + # a running program will not be interrupted when the original library is updated + lib_copy_path = '/tmp/lib{0}_{1}.so'.format( + str(uuid.uuid4())[:8], LIB_NAME) + shutil.copyfile(lib_path, lib_copy_path) + oplib = tf.load_op_library(lib_copy_path) + return oplib + + +def _nn_variable(name, shape, init_method, collection=None, **kwargs): + """ + Create or reuse a variable + ARGS + name: variable name + shape: variable shape + init_method: 'zero', 'kaiming', 'xavier', or (mean, std) + collection: if not none, add variable to this collection + kwargs: extra paramters passed to tf.get_variable + RETURN + var: a new or existing variable + """ + if init_method == 'zero': + initializer = tf.constant_initializer(0.0) + elif init_method == 'kaiming': + if len(shape) == 4: # convolutional filters + kh, kw, n_in = shape[:3] + init_std = math.sqrt(2.0 / (kh * kw * n_in)) + elif len(shape) == 2: # linear weights + n_in, n_out = shape + init_std = math.sqrt(1.0 / n_out) + else: + raise 'Unsupported shape' + initializer = tf.truncated_normal_initializer(0.0, init_std) + elif init_method == 'xavier': + if len(shape) == 4: + initializer = tf.keras.initializers.glorot_normal() + else: + initializer = tf.keras.initializers.glorot_normal() + elif isinstance(init_method, tuple): + assert (len(init_method) == 2) + initializer = tf.truncated_normal_initializer(init_method[0], + init_method[1]) + else: + raise 'Unsupported weight initialization method: ' + init_method + + var = tf.get_variable(name, shape=shape, initializer=initializer, **kwargs) + if collection is not None: + tf.add_to_collection(collection, var) + + return var + + +def conv2d(x, + n_in, + n_out, + ksize, + stride=1, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + scope=None, + **kwargs): + weight_init = weight_init or FLAGS.weight_init_method + trainable = kwargs.get('trainable', True) + # input_dim = n_in + if (padding == 'SAME'): + in_height = x.get_shape()[1] + in_width = x.get_shape()[2] + if (in_height % stride == 0): + pad_along_height = 
max(ksize - stride, 0) + else: + pad_along_height = max(ksize - (in_height % stride), 0) + if (in_width % stride == 0): + pad_along_width = max(ksize - stride, 0) + else: + pad_along_width = max(ksize - (in_width % stride), 0) + pad_bottom = pad_along_height // 2 + pad_top = pad_along_height - pad_bottom + pad_right = pad_along_width // 2 + pad_left = pad_along_width - pad_right + paddings = tf.constant([[0, 0], [pad_top, pad_bottom], + [pad_left, pad_right], [0, 0]]) + input_padded = tf.pad(x, paddings, 'CONSTANT') + else: + input_padded = x + + with tf.variable_scope(scope or 'conv2d'): + # convolution + kernel = _nn_variable( + 'weight', [ksize, ksize, n_in, n_out], + weight_init, + collection='weights' if trainable else None, + **kwargs) + yc = tf.nn.conv2d( + input_padded, kernel, [1, stride, stride, 1], padding='VALID') + # add bias + if bias is True: + bias = _nn_variable( + 'bias', [n_out], + 'zero', + collection='biases' if trainable else None, + **kwargs) + yb = tf.nn.bias_add(yc, bias) + # apply ReLU + y = yb + if relu is True: + y = tf.nn.relu(yb) + return yb, y + + +def group_conv2d_relu(x, + n_in, + n_out, + ksize, + stride=1, + group=4, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + name='group_conv2d', + **kwargs): + group_axis = len(x.get_shape()) - 1 + splits = tf.split(x, [int(n_in / group)] * group, group_axis) + + conv_list = [] + for i in range(group): + conv_split, relu_split = conv2d( + splits[i], + n_in / group, + n_out / group, + ksize=ksize, + stride=stride, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope='%s_%d' % (name, i)) + conv_list.append(conv_split) + conv = tf.concat(values=conv_list, axis=group_axis, name=name + '_concat') + relu = tf.nn.relu(conv) + return conv, relu + + +def group_conv2d_bn_relu(x, + n_in, + n_out, + ksize, + stride=1, + group=4, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + name='group_conv2d', + **kwargs): + group_axis = len(x.get_shape()) - 1 + splits = tf.split(x, [int(n_in / group)] * group, group_axis) + + conv_list = [] + for i in range(group): + conv_split, relu_split = conv2d( + splits[i], + n_in / group, + n_out / group, + ksize=ksize, + stride=stride, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope='%s_%d' % (name, i)) + conv_list.append(conv_split) + conv = tf.concat(values=conv_list, axis=group_axis, name=name + '_concat') + with tf.variable_scope(name + '_bn'): + bn = tf.layers.batch_normalization( + conv, momentum=0.9, epsilon=1e-5, scale=True, training=True) + relu = tf.nn.relu(bn) + return conv, relu + + +def next_conv(x, + n_in, + n_out, + ksize, + stride=1, + group=4, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + name='next_conv2d', + **kwargs): + conv_a, relu_a = conv_relu( + x, + n_in, + n_in / 2, + ksize=1, + stride=1, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope=name + '_a', + **kwargs) + + conv_b, relu_b = group_conv2d_relu( + relu_a, + n_in / 2, + n_out / 2, + ksize=ksize, + stride=stride, + group=group, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + name=name + '_b', + **kwargs) + + conv_c, relu_c = conv_relu( + relu_b, + n_out / 2, + n_out, + ksize=1, + stride=1, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope=name + '_c', + **kwargs) + + return conv_c, relu_c + + +def next_conv_bn(x, + n_in, + n_out, + ksize, + stride=1, + group=4, + padding='SAME', + weight_init=None, + bias=True, + 
relu=False, + name='next_conv2d', + **kwargs): + conv_a, relu_a = conv_bn_relu( + x, + n_in, + n_in / 2, + ksize=1, + stride=1, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope=name + '_a', + **kwargs) + + conv_b, relu_b = group_conv2d_bn_relu( + relu_a, + n_in / 2, + n_out / 2, + ksize=ksize, + stride=stride, + group=group, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + name=name + '_b', + **kwargs) + + conv_c, relu_c = conv_bn_relu( + relu_b, + n_out / 2, + n_out, + ksize=1, + stride=1, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope=name + '_c', + **kwargs) + + return conv_c, relu_c + + +def conv2d_ori(x, + n_in, + n_out, + ksize, + stride=1, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + scope=None, + **kwargs): + weight_init = weight_init or FLAGS.weight_init_method + trainable = kwargs.get('trainable', True) + + with tf.variable_scope(scope or 'conv2d'): + # convolution + kernel = _nn_variable( + 'weight', [ksize, ksize, n_in, n_out], + weight_init, + collection='weights' if trainable else None, + **kwargs) + y = tf.nn.conv2d(x, kernel, [1, stride, stride, 1], padding=padding) + # add bias + if bias is True: + bias = _nn_variable( + 'bias', [n_out], + 'zero', + collection='biases' if trainable else None, + **kwargs) + y = tf.nn.bias_add(y, bias) + # apply ReLU + if relu is True: + y = tf.nn.relu(y) + return y + + +def conv_relu(*args, **kwargs): + kwargs['relu'] = True + if 'scope' not in kwargs: + kwargs['scope'] = 'conv_relu' + return conv2d(*args, **kwargs) + + +def conv_bn_relu(*args, **kwargs): + kwargs['relu'] = True + if 'scope' not in kwargs: + kwargs['scope'] = 'conv_relu' + conv, relu = conv2d(*args, **kwargs) + with tf.variable_scope(kwargs['scope'] + '_bn'): + bn = tf.layers.batch_normalization( + conv, momentum=0.9, epsilon=1e-5, scale=True, training=True) + bn_relu = tf.nn.relu(bn) + return bn, bn_relu + + +def conv_relu_ori(*args, **kwargs): + kwargs['relu'] = True + if 'scope' not in kwargs: + kwargs['scope'] = 'conv_relu' + return conv2d_ori(*args, **kwargs) + + +def atrous_conv2d(x, + n_in, + n_out, + ksize, + dilation, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + scope=None, + **kwargs): + weight_init = weight_init or FLAGS.weight_init_method + trainable = kwargs.get('trainable', True) + with tf.variable_scope(scope or 'atrous_conv2d'): + # atrous convolution + kernel = _nn_variable( + 'weight', [ksize, ksize, n_in, n_out], + weight_init, + collection='weights' if trainable else None, + **kwargs) + y = tf.nn.atrous_conv2d(x, kernel, dilation, padding=padding) + # add bias + if bias is True: + bias = _nn_variable( + 'bias', [n_out], + 'zero', + collection='biases' if trainable else None, + **kwargs) + y = tf.nn.bias_add(y, bias) + # apply ReLU + if relu is True: + y = tf.nn.relu(y) + return y + + +def avg_pool(x, ksize, stride, padding='SAME', scope=None): + with tf.variable_scope(scope or 'avg_pool'): + y = tf.nn.avg_pool(x, [1, ksize, ksize, 1], [1, stride, stride, 1], + padding) + return y + + +def max_pool(x, ksize, stride, padding='SAME', scope=None): + with tf.variable_scope(scope or 'max_pool'): + y = tf.nn.max_pool(x, [1, ksize, ksize, 1], [1, stride, stride, 1], + padding) + return y + + +def score_loss(gt_labels, match_scores, n_classes): + """ + Classification loss + ARGS + gt_labels: int32 [n] + match_scores: [n, n_classes] + RETURN + loss + """ + embeddings = tf.one_hot(tf.cast(gt_labels, tf.int64), n_classes, 1.0, 0.0) + 
losses = tf.nn.softmax_cross_entropy_with_logits(match_scores, embeddings) + return tf.reduce_sum(losses) + + +def smooth_l1_loss(offsets, gt_offsets, scope=None): + """ + Smooth L1 loss between offsets and encoded_gt + ARGS + offsets: [m?, 5], predicted offsets for one example + gt_offsets: [m?, 5], correponding groundtruth offsets + RETURN + loss: scalar + """ + with tf.variable_scope(scope or 'smooth_l1_loss'): + gt_offsets = tf.stop_gradient(gt_offsets) + diff = tf.abs(offsets - gt_offsets) + lesser_mask = tf.cast(tf.less(diff, 1.0), tf.float32) + larger_mask = 1.0 - lesser_mask + losses1 = (0.5 * tf.square(diff)) * lesser_mask + losses2 = (diff - 0.5) * larger_mask + return tf.reduce_sum(losses1 + losses2, 1) + + +def polygon_to_rboxe(polygon): + x1 = polygon[0] + y1 = polygon[1] + x2 = polygon[2] + y2 = polygon[3] + x3 = polygon[4] + y3 = polygon[5] + x4 = polygon[6] + y4 = polygon[7] + c_x = (x1 + x2 + x3 + x4) / 4 + c_y = (y1 + y2 + y3 + y4) / 4 + w1 = point_dist(x1, y1, x2, y2) + w2 = point_dist(x3, y3, x4, y4) + h1 = point_line_dist(c_x, c_y, x1, y1, x2, y2) + h2 = point_line_dist(c_x, c_y, x3, y3, x4, y4) + h = h1 + h2 + w = (w1 + w2) / 2 + theta1 = np.arctan2(y2 - y1, x2 - x1) + theta2 = np.arctan2(y3 - y4, x3 - x4) + theta = (theta1 + theta2) / 2 + return np.array([c_x, c_y, w, h, theta]) + + +def point_dist(x1, y1, x2, y2): + return np.sqrt((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)) + + +def point_line_dist(px, py, x1, y1, x2, y2): + eps = 1e-6 + dx = x2 - x1 + dy = y2 - y1 + div = np.sqrt(dx * dx + dy * dy) + eps + dist = np.abs(px * dy - py * dx + x2 * y1 - y2 * x1) / div + return dist + + +def get_combined_polygon(rboxes, resize_size): + image_w = resize_size[1] + image_h = resize_size[0] + img = np.zeros((image_h, image_w, 3), np.uint8) + for i in range(rboxes.shape[0]): + segment = np.reshape( + np.array(utils.rboxes_to_polygons(rboxes)[i, :], np.int32), + (-1, 1, 2)) + cv2.drawContours(img, [segment], 0, (255, 255, 255), -1) + img2gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + ret, thresh = cv2.threshold(img2gray, 127, 255, cv2.THRESH_BINARY) + im2, contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE) + if len(contours) > 0: + cnt = contours[0] + max_area = cv2.contourArea(cnt) + # get max_area + for cont in contours: + if cv2.contourArea(cont) > max_area: + cnt = cont + max_area = cv2.contourArea(cont) + rect = cv2.minAreaRect(cnt) + combined_polygon = np.array(cv2.boxPoints(rect)).reshape(-1) + else: + combined_polygon = np.array([0, 0, 0, 0, 0, 0, 0, 0]) + + return combined_polygon + + +def combine_segs(segs): + segs = np.asarray(segs) + assert segs.ndim == 2, 'invalid segs ndim' + assert segs.shape[-1] == 6, 'invalid segs shape' + + if len(segs) == 1: + cx = segs[0, 0] + cy = segs[0, 1] + w = segs[0, 2] + h = segs[0, 3] + theta_sin = segs[0, 4] + theta_cos = segs[0, 5] + theta = np.arctan2(theta_sin, theta_cos) + return np.array([cx, cy, w, h, theta]) + + # find the best straight line fitting all center points: y = kx + b + cxs = segs[:, 0] + cys = segs[:, 1] + + theta_coss = segs[:, 4] + theta_sins = segs[:, 5] + + bar_theta = np.arctan2(theta_sins.sum(), theta_coss.sum()) + k = np.tan(bar_theta) + b = np.mean(cys - k * cxs) + + proj_xs = (k * cys + cxs - k * b) / (k**2 + 1) + proj_ys = (k * k * cys + k * cxs + b) / (k**2 + 1) + proj_points = np.stack((proj_xs, proj_ys), -1) + + # find the max distance + max_dist = -1 + idx1 = -1 + idx2 = -1 + + for i in range(len(proj_points)): + point1 = proj_points[i, :] + for j in range(i 
+ 1, len(proj_points)): + point2 = proj_points[j, :] + dist = np.sqrt(np.sum((point1 - point2)**2)) + if dist > max_dist: + idx1 = i + idx2 = j + max_dist = dist + assert idx1 >= 0 and idx2 >= 0 + # the bbox: bcx, bcy, bw, bh, average_theta + seg1 = segs[idx1, :] + seg2 = segs[idx2, :] + bcx, bcy = (seg1[:2] + seg2[:2]) / 2.0 + bh = np.mean(segs[:, 3]) + bw = max_dist + (seg1[2] + seg2[2]) / 2.0 + return bcx, bcy, bw, bh, bar_theta + + +def combine_segments_batch(segments_batch, group_indices_batch, + segment_counts_batch): + batch_size = 1 + combined_rboxes_batch = [] + combined_counts_batch = [] + for image_id in range(batch_size): + group_count = segment_counts_batch[image_id] + segments = segments_batch[image_id, :, :] + group_indices = group_indices_batch[image_id, :] + combined_rboxes = [] + for i in range(group_count): + segments_group = segments[np.where(group_indices == i)[0], :] + if segments_group.shape[0] > 0: + combined_rbox = combine_segs(segments_group) + combined_rboxes.append(combined_rbox) + combined_rboxes_batch.append(combined_rboxes) + combined_counts_batch.append(len(combined_rboxes)) + + max_count = np.max(combined_counts_batch) + for image_id in range(batch_size): + if not combined_counts_batch[image_id] == max_count: + combined_rboxes_pad = (max_count - combined_counts_batch[image_id] + ) * [RBOX_DIM * [0.0]] + combined_rboxes_batch[image_id] = np.vstack( + (combined_rboxes_batch[image_id], + np.array(combined_rboxes_pad))) + + return np.asarray(combined_rboxes_batch, + np.float32), np.asarray(combined_counts_batch, np.int32) + + +# combine_segments rewrite in python version +def combine_segments_python(segments, group_indices, segment_counts): + combined_rboxes, combined_counts = tf.py_func( + combine_segments_batch, [segments, group_indices, segment_counts], + [tf.float32, tf.int32]) + return combined_rboxes, combined_counts + + +# decode_segments_links rewrite in python version +def get_coord(offsets, map_size, offsets_defaults): + if offsets < offsets_defaults[1][0]: + l_idx = 0 + x = offsets % map_size[0][1] + y = offsets // map_size[0][1] + elif offsets < offsets_defaults[2][0]: + l_idx = 1 + x = (offsets - offsets_defaults[1][0]) % map_size[1][1] + y = (offsets - offsets_defaults[1][0]) // map_size[1][1] + elif offsets < offsets_defaults[3][0]: + l_idx = 2 + x = (offsets - offsets_defaults[2][0]) % map_size[2][1] + y = (offsets - offsets_defaults[2][0]) // map_size[2][1] + elif offsets < offsets_defaults[4][0]: + l_idx = 3 + x = (offsets - offsets_defaults[3][0]) % map_size[3][1] + y = (offsets - offsets_defaults[3][0]) // map_size[3][1] + elif offsets < offsets_defaults[5][0]: + l_idx = 4 + x = (offsets - offsets_defaults[4][0]) % map_size[4][1] + y = (offsets - offsets_defaults[4][0]) // map_size[4][1] + else: + l_idx = 5 + x = (offsets - offsets_defaults[5][0]) % map_size[5][1] + y = (offsets - offsets_defaults[5][0]) // map_size[5][1] + + return l_idx, x, y + + +def get_coord_link(offsets, map_size, offsets_defaults): + if offsets < offsets_defaults[1][1]: + offsets_node = offsets // N_LOCAL_LINKS + link_idx = offsets % N_LOCAL_LINKS + else: + offsets_node = (offsets - offsets_defaults[1][1]) // ( + N_LOCAL_LINKS + N_CROSS_LINKS) + offsets_defaults[1][0] + link_idx = (offsets - offsets_defaults[1][1]) % ( + N_LOCAL_LINKS + N_CROSS_LINKS) + l_idx, x, y = get_coord(offsets_node, map_size, offsets_defaults) + return l_idx, x, y, link_idx + + +def is_valid_coord(l_idx, x, y, map_size): + w = map_size[l_idx][1] + h = map_size[l_idx][0] + return x >= 0 and x < 
w and y >= 0 and y < h + + +def get_neighbours(l_idx, x, y, map_size, offsets_defaults): + if l_idx == 0: + coord = [(0, x - 1, y - 1), (0, x, y - 1), (0, x + 1, y - 1), + (0, x - 1, y), (0, x + 1, y), (0, x - 1, y + 1), + (0, x, y + 1), (0, x + 1, y + 1)] + else: + coord = [(l_idx, x - 1, y - 1), + (l_idx, x, y - 1), (l_idx, x + 1, y - 1), (l_idx, x - 1, y), + (l_idx, x + 1, y), (l_idx, x - 1, y + 1), (l_idx, x, y + 1), + (l_idx, x + 1, y + 1), (l_idx - 1, 2 * x, 2 * y), + (l_idx - 1, 2 * x + 1, 2 * y), (l_idx - 1, 2 * x, 2 * y + 1), + (l_idx - 1, 2 * x + 1, 2 * y + 1)] + neighbours_offsets = [] + link_idx = 0 + for nl_idx, nx, ny in coord: + if is_valid_coord(nl_idx, nx, ny, map_size): + neighbours_offset_node = offsets_defaults[nl_idx][ + 0] + map_size[nl_idx][1] * ny + nx + if l_idx == 0: + neighbours_offset_link = offsets_defaults[l_idx][1] + ( + map_size[l_idx][1] * y + x) * N_LOCAL_LINKS + link_idx + else: + off_tmp = (map_size[l_idx][1] * y + x) * ( + N_LOCAL_LINKS + N_CROSS_LINKS) + neighbours_offset_link = offsets_defaults[l_idx][ + 1] + off_tmp + link_idx + neighbours_offsets.append( + [neighbours_offset_node, neighbours_offset_link, link_idx]) + link_idx += 1 + # [node_offsets, link_offsets, link_idx(0-7/11)] + return neighbours_offsets + + +def decode_segments_links_python(image_size, all_nodes, all_links, all_reg, + anchor_sizes): + batch_size = 1 # FLAGS.test_batch_size + # offsets = 12285 #768 + all_nodes_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, N_SEG_CLASSES]) for o in all_nodes], + axis=1) + all_links_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, N_LNK_CLASSES]) for o in all_links], + axis=1) + all_reg_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, OFFSET_DIM]) for o in all_reg], axis=1) + segments, group_indices, segment_counts, group_indices_all = tf.py_func( + decode_batch, [ + all_nodes_flat, all_links_flat, all_reg_flat, image_size, + tf.constant(anchor_sizes) + ], [tf.float32, tf.int32, tf.int32, tf.int32]) + return segments, group_indices, segment_counts, group_indices_all + + +def decode_segments_links_train(image_size, all_nodes, all_links, all_reg, + anchor_sizes): + batch_size = FLAGS.train_batch_size + # offsets = 12285 #768 + all_nodes_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, N_SEG_CLASSES]) for o in all_nodes], + axis=1) + all_links_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, N_LNK_CLASSES]) for o in all_links], + axis=1) + all_reg_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, OFFSET_DIM]) for o in all_reg], axis=1) + segments, group_indices, segment_counts, group_indices_all = tf.py_func( + decode_batch, [ + all_nodes_flat, all_links_flat, all_reg_flat, image_size, + tf.constant(anchor_sizes) + ], [tf.float32, tf.int32, tf.int32, tf.int32]) + return segments, group_indices, segment_counts, group_indices_all + + +def decode_batch(all_nodes, all_links, all_reg, image_size, anchor_sizes): + batch_size = all_nodes.shape[0] + batch_segments = [] + batch_group_indices = [] + batch_segments_counts = [] + batch_group_indices_all = [] + for image_id in range(batch_size): + image_node_scores = all_nodes[image_id, :, :] + image_link_scores = all_links[image_id, :, :] + image_reg = all_reg[image_id, :, :] + image_segments, image_group_indices, image_segments_counts, image_group_indices_all = decode_image( + image_node_scores, image_link_scores, image_reg, image_size, + anchor_sizes) + batch_segments.append(image_segments) + batch_group_indices.append(image_group_indices) + 
batch_segments_counts.append(image_segments_counts) + batch_group_indices_all.append(image_group_indices_all) + max_count = np.max(batch_segments_counts) + for image_id in range(batch_size): + if not batch_segments_counts[image_id] == max_count: + batch_segments_pad = (max_count - batch_segments_counts[image_id] + ) * [OFFSET_DIM * [0.0]] + batch_segments[image_id] = np.vstack( + (batch_segments[image_id], np.array(batch_segments_pad))) + batch_group_indices[image_id] = np.hstack( + (batch_group_indices[image_id], + np.array( + (max_count - batch_segments_counts[image_id]) * [-1]))) + return np.asarray(batch_segments, np.float32), np.asarray( + batch_group_indices, + np.int32), np.asarray(batch_segments_counts, + np.int32), np.asarray(batch_group_indices_all, + np.int32) + + +def decode_image(image_node_scores, image_link_scores, image_reg, image_size, + anchor_sizes): + map_size = [] + offsets_defaults = [] + offsets_default_node = 0 + offsets_default_link = 0 + for i in range(N_DET_LAYERS): + offsets_defaults.append([offsets_default_node, offsets_default_link]) + map_size.append(image_size // (2**(2 + i))) + offsets_default_node += map_size[i][0] * map_size[i][1] + if i == 0: + offsets_default_link += map_size[i][0] * map_size[i][ + 1] * N_LOCAL_LINKS + else: + offsets_default_link += map_size[i][0] * map_size[i][1] * ( + N_LOCAL_LINKS + N_CROSS_LINKS) + + image_group_indices_all = decode_image_by_join(image_node_scores, + image_link_scores, + FLAGS.node_threshold, + FLAGS.link_threshold, + map_size, offsets_defaults) + image_group_indices_all -= 1 + image_group_indices = image_group_indices_all[np.where( + image_group_indices_all >= 0)[0]] + image_segments_counts = len(image_group_indices) + # convert image_reg to segments with scores(OFFSET_DIM+1) + image_segments = np.zeros((image_segments_counts, OFFSET_DIM), + dtype=np.float32) + for i, offsets in enumerate(np.where(image_group_indices_all >= 0)[0]): + encoded_cx = image_reg[offsets, 0] + encoded_cy = image_reg[offsets, 1] + encoded_width = image_reg[offsets, 2] + encoded_height = image_reg[offsets, 3] + encoded_theta_cos = image_reg[offsets, 4] + encoded_theta_sin = image_reg[offsets, 5] + + l_idx, x, y = get_coord(offsets, map_size, offsets_defaults) + rs = anchor_sizes[l_idx] + eps = 1e-6 + image_segments[i, 0] = encoded_cx * rs + (2**(2 + l_idx)) * (x + 0.5) + image_segments[i, 1] = encoded_cy * rs + (2**(2 + l_idx)) * (y + 0.5) + image_segments[i, 2] = np.exp(encoded_width) * rs - eps + image_segments[i, 3] = np.exp(encoded_height) * rs - eps + image_segments[i, 4] = encoded_theta_cos + image_segments[i, 5] = encoded_theta_sin + + return image_segments, image_group_indices, image_segments_counts, image_group_indices_all + + +def decode_image_by_join(node_scores, link_scores, node_threshold, + link_threshold, map_size, offsets_defaults): + node_mask = node_scores[:, POS_LABEL] >= node_threshold + link_mask = link_scores[:, POS_LABEL] >= link_threshold + group_mask = np.zeros_like(node_mask, np.int32) - 1 + offsets_pos = np.where(node_mask == 1)[0] + + def find_parent(point): + return group_mask[point] + + def set_parent(point, parent): + group_mask[point] = parent + + def is_root(point): + return find_parent(point) == -1 + + def find_root(point): + root = point + update_parent = False + while not is_root(root): + root = find_parent(root) + update_parent = True + + # for acceleration of find_root + if update_parent: + set_parent(point, root) + + return root + + def join(p1, p2): + root1 = find_root(p1) + root2 = find_root(p2) + + 
if root1 != root2: + set_parent(root1, root2) + + def get_all(): + root_map = {} + + def get_index(root): + if root not in root_map: + root_map[root] = len(root_map) + 1 + return root_map[root] + + mask = np.zeros_like(node_mask, dtype=np.int32) + for i, point in enumerate(offsets_pos): + point_root = find_root(point) + bbox_idx = get_index(point_root) + mask[point] = bbox_idx + return mask + + # join by link + pos_link = 0 + for i, offsets in enumerate(offsets_pos): + l_idx, x, y = get_coord(offsets, map_size, offsets_defaults) + neighbours = get_neighbours(l_idx, x, y, map_size, offsets_defaults) + for n_idx, noffsets in enumerate(neighbours): + link_value = link_mask[noffsets[1]] + node_cls = node_mask[noffsets[0]] + if link_value and node_cls: + pos_link += 1 + join(offsets, noffsets[0]) + # print(pos_link) + mask = get_all() + return mask + + +def get_link_mask(node_mask, offsets_defaults, link_max): + link_mask = np.zeros_like(link_max) + link_mask[0:offsets_defaults[1][1]] = np.tile( + node_mask[0:offsets_defaults[1][0]], + (N_LOCAL_LINKS, 1)).transpose().reshape(offsets_defaults[1][1]) + link_mask[offsets_defaults[1][1]:offsets_defaults[2][1]] = np.tile( + node_mask[offsets_defaults[1][0]:offsets_defaults[2][0]], + (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape( + (offsets_defaults[2][1] - offsets_defaults[1][1])) + link_mask[offsets_defaults[2][1]:offsets_defaults[3][1]] = np.tile( + node_mask[offsets_defaults[2][0]:offsets_defaults[3][0]], + (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape( + (offsets_defaults[3][1] - offsets_defaults[2][1])) + link_mask[offsets_defaults[3][1]:offsets_defaults[4][1]] = np.tile( + node_mask[offsets_defaults[3][0]:offsets_defaults[4][0]], + (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape( + (offsets_defaults[4][1] - offsets_defaults[3][1])) + link_mask[offsets_defaults[4][1]:offsets_defaults[5][1]] = np.tile( + node_mask[offsets_defaults[4][0]:offsets_defaults[5][0]], + (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape( + (offsets_defaults[5][1] - offsets_defaults[4][1])) + link_mask[offsets_defaults[5][1]:] = np.tile( + node_mask[offsets_defaults[5][0]:], + (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape( + (len(link_mask) - offsets_defaults[5][1])) + + return link_mask + + +def get_link8(link_scores_raw, map_size): + # link[i-1] -local- start -16- end -cross- link[i] + link8_mask = np.zeros((link_scores_raw.shape[0])) + for i in range(N_DET_LAYERS): + if i == 0: + offsets_start = map_size[i][0] * map_size[i][1] * N_LOCAL_LINKS + offsets_end = map_size[i][0] * map_size[i][1] * ( + N_LOCAL_LINKS + 16) + offsets_link = map_size[i][0] * map_size[i][1] * ( + N_LOCAL_LINKS + 16) + link8_mask[:offsets_start] = 1 + else: + offsets_start = offsets_link + map_size[i][0] * map_size[i][ + 1] * N_LOCAL_LINKS + offsets_end = offsets_link + map_size[i][0] * map_size[i][1] * ( + N_LOCAL_LINKS + 16) + offsets_link_pre = offsets_link + offsets_link += map_size[i][0] * map_size[i][1] * ( + N_LOCAL_LINKS + 16 + N_CROSS_LINKS) + link8_mask[offsets_link_pre:offsets_start] = 1 + link8_mask[offsets_end:offsets_link] = 1 + return link_scores_raw[np.where(link8_mask > 0)[0], :] + + +def decode_image_by_mutex(node_scores, link_scores, node_threshold, + link_threshold, map_size, offsets_defaults): + node_mask = node_scores[:, POS_LABEL] >= node_threshold + link_pos = link_scores[:, POS_LABEL] + link_mut = link_scores[:, MUT_LABEL] + link_max = np.max(np.vstack((link_pos, link_mut)), axis=0) + + offsets_pos_list = np.where(node_mask == 
1)[0].tolist() + + link_mask_th = link_max >= link_threshold + link_mask = get_link_mask(node_mask, offsets_defaults, link_max) + offsets_link_max = np.argsort(-(link_max * link_mask * link_mask_th)) + offsets_link_max = offsets_link_max[:len(offsets_pos_list) * 8] + + group_mask = np.zeros_like(node_mask, dtype=np.int32) - 1 + mutex_mask = len(node_mask) * [[]] + + def find_parent(point): + return group_mask[point] + + def set_parent(point, parent): + group_mask[point] = parent + + def set_mutex_constraint(point, mutex_point_list): + mutex_mask[point] = mutex_point_list + + def find_mutex_constraint(point): + mutex_point_list = mutex_mask[point] + # update mutex_point_list + mutex_point_list_new = [] + if not mutex_point_list == []: + for mutex_point in mutex_point_list: + if not is_root(mutex_point): + mutex_point = find_root(mutex_point) + if mutex_point not in mutex_point_list_new: + mutex_point_list_new.append(mutex_point) + set_mutex_constraint(point, mutex_point_list_new) + return mutex_point_list_new + + def combine_mutex_constraint(point, parent): + mutex_point_list = find_mutex_constraint(point) + mutex_parent_list = find_mutex_constraint(parent) + for mutex_point in mutex_point_list: + if not is_root(mutex_point): + mutex_point = find_root(mutex_point) + if mutex_point not in mutex_parent_list: + mutex_parent_list.append(mutex_point) + set_mutex_constraint(parent, mutex_parent_list) + + def add_mutex_constraint(p1, p2): + mutex_point_list1 = find_mutex_constraint(p1) + mutex_point_list2 = find_mutex_constraint(p2) + + if p1 not in mutex_point_list2: + mutex_point_list2.append(p1) + if p2 not in mutex_point_list1: + mutex_point_list1.append(p2) + set_mutex_constraint(p1, mutex_point_list1) + set_mutex_constraint(p2, mutex_point_list2) + + def is_root(point): + return find_parent(point) == -1 + + def find_root(point): + root = point + update_parent = False + while not is_root(root): + root = find_parent(root) + update_parent = True + + # for acceleration of find_root + if update_parent: + set_parent(point, root) + + return root + + def join(p1, p2): + root1 = find_root(p1) + root2 = find_root(p2) + + if root1 != root2 and (root1 not in find_mutex_constraint(root2)): + set_parent(root1, root2) + combine_mutex_constraint(root1, root2) + + def disjoin(p1, p2): + root1 = find_root(p1) + root2 = find_root(p2) + + if root1 != root2: + add_mutex_constraint(root1, root2) + + def get_all(): + root_map = {} + + def get_index(root): + if root not in root_map: + root_map[root] = len(root_map) + 1 + return root_map[root] + + mask = np.zeros_like(node_mask, dtype=np.int32) + for _, point in enumerate(offsets_pos_list): + point_root = find_root(point) + bbox_idx = get_index(point_root) + mask[point] = bbox_idx + return mask + + # join by link + pos_link = 0 + mut_link = 0 + for _, offsets_link in enumerate(offsets_link_max): + l_idx, x, y, link_idx = get_coord_link(offsets_link, map_size, + offsets_defaults) + offsets = offsets_defaults[l_idx][0] + map_size[l_idx][1] * y + x + if offsets in offsets_pos_list: + neighbours = get_neighbours(l_idx, x, y, map_size, + offsets_defaults) + if not len(np.where(np.array(neighbours)[:, + 2] == link_idx)[0]) == 0: + noffsets = neighbours[np.where( + np.array(neighbours)[:, 2] == link_idx)[0][0]] + link_pos_value = link_pos[noffsets[1]] + link_mut_value = link_mut[noffsets[1]] + node_cls = node_mask[noffsets[0]] + if node_cls and (link_pos_value > link_mut_value): + pos_link += 1 + join(offsets, noffsets[0]) + elif node_cls and (link_pos_value < 
link_mut_value): + mut_link += 1 + disjoin(offsets, noffsets[0]) + + mask = get_all() + return mask diff --git a/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py b/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py new file mode 100644 index 00000000..6371d4e5 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py @@ -0,0 +1,432 @@ +"""Contains definitions for the original form of Residual Networks. +The 'v1' residual networks (ResNets) implemented in this module were proposed +by: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385 +Other variants were introduced in: +[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 +The networks defined in this module utilize the bottleneck building block of +[1] with projection shortcuts only for increasing depths. They employ batch +normalization *after* every weight layer. This is the architecture used by +MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and +ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1' +architecture and the alternative 'v2' architecture of [2] which uses batch +normalization *before* every weight layer in the so-called full pre-activation +units. +Typical use: + from tensorflow.contrib.slim.nets import resnet_v1 +ResNet-101 for image classification into 1000 classes: + # inputs has shape [batch, 224, 224, 3] + with slim.arg_scope(resnet_v1.resnet_arg_scope()): + net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) +ResNet-101 for semantic segmentation into 21 classes: + # inputs has shape [batch, 513, 513, 3] + with slim.arg_scope(resnet_v1.resnet_arg_scope()): + net, end_points = resnet_v1.resnet_v1_101(inputs, + 21, + is_training=False, + global_pool=False, + output_stride=16) +""" +import tensorflow as tf +import tf_slim as slim + +from . import resnet_utils + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + +resnet_arg_scope = resnet_utils.resnet_arg_scope + + +@slim.add_arg_scope +def basicblock(inputs, + depth, + depth_bottleneck, + stride, + rate=1, + outputs_collections=None, + scope=None): + """Bottleneck residual unit variant with BN after convolutions. + This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for + its definition. Note that we use here the bottleneck variant which has an + extra bottleneck layer. + When putting together two consecutive ResNet blocks that use this unit, one + should use stride = 2 in the last unit of the first block. + Args: + inputs: A tensor of size [batch, height, width, channels]. + depth: The depth of the ResNet unit output. + depth_bottleneck: The depth of the bottleneck layers. + stride: The ResNet unit's stride. Determines the amount of downsampling of + the units output compared to its input. + rate: An integer, rate for atrous convolution. + outputs_collections: Collection to add the ResNet unit output. + scope: Optional variable_scope. + Returns: + The ResNet unit's output. 
+ """ + with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: + depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) + if depth == depth_in: + shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') + else: + shortcut = slim.conv2d( + inputs, + depth, [1, 1], + stride=stride, + activation_fn=None, + scope='shortcut') + + residual = resnet_utils.conv2d_same( + inputs, depth, 3, stride, rate=rate, scope='conv1') + residual = resnet_utils.conv2d_same( + residual, depth, 3, 1, rate=rate, scope='conv2') + + output = tf.nn.relu(residual + shortcut) + + return slim.utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, output) + + +@slim.add_arg_scope +def bottleneck(inputs, + depth, + depth_bottleneck, + stride, + rate=1, + outputs_collections=None, + scope=None): + """Bottleneck residual unit variant with BN after convolutions. + This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for + its definition. Note that we use here the bottleneck variant which has an + extra bottleneck layer. + When putting together two consecutive ResNet blocks that use this unit, one + should use stride = 2 in the last unit of the first block. + Args: + inputs: A tensor of size [batch, height, width, channels]. + depth: The depth of the ResNet unit output. + depth_bottleneck: The depth of the bottleneck layers. + stride: The ResNet unit's stride. Determines the amount of downsampling of + the units output compared to its input. + rate: An integer, rate for atrous convolution. + outputs_collections: Collection to add the ResNet unit output. + scope: Optional variable_scope. + Returns: + The ResNet unit's output. + """ + with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: + depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) + if depth == depth_in: + shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') + else: + shortcut = slim.conv2d( + inputs, + depth, [1, 1], + stride=stride, + activation_fn=None, + scope='shortcut') + + residual = slim.conv2d( + inputs, depth_bottleneck, [1, 1], stride=1, scope='conv1') + residual = resnet_utils.conv2d_same( + residual, depth_bottleneck, 3, stride, rate=rate, scope='conv2') + residual = slim.conv2d( + residual, + depth, [1, 1], + stride=1, + activation_fn=None, + scope='conv3') + + output = tf.nn.relu(shortcut + residual) + + return slim.utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, output) + + +def resnet_v1(inputs, + blocks, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + include_root_block=True, + spatial_squeeze=True, + reuse=None, + scope=None): + """Generator for v1 ResNet models. + This function generates a family of ResNet v1 models. See the resnet_v1_*() + methods for specific model instantiations, obtained by selecting different + block instantiations that produce ResNets of various depths. + Training for image classification on Imagenet is usually done with [224, 224] + inputs, resulting in [7, 7] feature maps at the output of the last ResNet + block for the ResNets defined in [1] that have nominal stride equal to 32. + However, for dense prediction tasks we advise that one uses inputs with + spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. 
In + this case the feature maps at the ResNet output will have spatial shape + [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] + and corners exactly aligned with the input image corners, which greatly + facilitates alignment of the features to the image. Using as input [225, 225] + images results in [8, 8] feature maps at the output of the last ResNet block. + For dense prediction tasks, the ResNet needs to run in fully-convolutional + (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all + have nominal stride equal to 32 and a good choice in FCN mode is to use + output_stride=16 in order to increase the density of the computed features at + small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. + Args: + inputs: A tensor of size [batch, height_in, width_in, channels]. + blocks: A list of length equal to the number of ResNet blocks. Each element + is a resnet_utils.Block object describing the units in the block. + num_classes: Number of predicted classes for classification tasks. If None + we return the features before the logit layer. + is_training: whether is training or not. + global_pool: If True, we perform global average pooling before computing the + logits. Set to True for image classification, False for dense prediction. + output_stride: If None, then the output will be computed at the nominal + network stride. If output_stride is not None, it specifies the requested + ratio of input to output spatial resolution. + include_root_block: If True, include the initial convolution followed by + max-pooling, if False excludes it. + spatial_squeeze: if True, logits is of shape [B, C], if false logits is + of shape [B, 1, 1, C], where B is batch_size and C is number of classes. + reuse: whether or not the network and its variables should be reused. To be + able to reuse 'scope' must be given. + scope: Optional variable_scope. + Returns: + net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. + If global_pool is False, then height_out and width_out are reduced by a + factor of output_stride compared to the respective height_in and width_in, + else both height_out and width_out equal one. If num_classes is None, then + net is the output of the last ResNet block, potentially after global + average pooling. If num_classes is not None, net contains the pre-softmax + activations. + end_points: A dictionary from components of the network to the corresponding + activation. + Raises: + ValueError: If the target output_stride is not valid. + """ + with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc: + end_points_collection = sc.name + '_end_points' + with slim.arg_scope( + [slim.conv2d, bottleneck, resnet_utils.stack_blocks_dense], + outputs_collections=end_points_collection): + with slim.arg_scope([slim.batch_norm], is_training=is_training): + net = inputs + if include_root_block: + if output_stride is not None: + if output_stride % 4 != 0: + raise ValueError( + 'The output_stride needs to be a multiple of 4.' 
+ ) + output_stride /= 4 + net = resnet_utils.conv2d_same( + net, 64, 7, stride=2, scope='conv1') + net = tf.pad(net, [[0, 0], [1, 1], [1, 1], [0, 0]]) + net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') + + net = slim.utils.collect_named_outputs( + end_points_collection, 'pool2', net) + + net = resnet_utils.stack_blocks_dense(net, blocks, + output_stride) + + end_points = slim.utils.convert_collection_to_dict( + end_points_collection) + + end_points['pool1'] = end_points['resnet_v1_18/block2/unit_2'] + end_points['pool2'] = end_points['resnet_v1_18/block3/unit_2'] + end_points['pool3'] = end_points['resnet_v1_18/block4/unit_2'] + end_points['pool4'] = end_points['resnet_v1_18/block5/unit_2'] + end_points['pool5'] = end_points['resnet_v1_18/block6/unit_2'] + end_points['pool6'] = net + + return net, end_points + + +resnet_v1.default_image_size = 224 + + +def resnet_v1_18(inputs, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + spatial_squeeze=True, + reuse=None, + scope='resnet_v1_18'): + """ResNet-18 model of [1]. See resnet_v1() for arg and return description.""" + blocks = [ + resnet_utils.Block('block1', basicblock, + [(64, 64, 1)] + [(64, 64, 1)]), + resnet_utils.Block('block2', basicblock, + [(128, 128, 1)] + [(128, 128, 1)]), + resnet_utils.Block('block3', basicblock, + [(256, 256, 2)] + [(256, 256, 1)]), + resnet_utils.Block('block4', basicblock, + [(512, 512, 2)] + [(512, 512, 1)]), + resnet_utils.Block('block5', basicblock, + [(256, 256, 2)] + [(256, 256, 1)]), + resnet_utils.Block('block6', basicblock, + [(256, 256, 2)] + [(256, 256, 1)]), + resnet_utils.Block('block7', basicblock, + [(256, 256, 2)] + [(256, 256, 1)]), + ] + return resnet_v1( + inputs, + blocks, + num_classes, + is_training, + global_pool=global_pool, + output_stride=output_stride, + include_root_block=True, + spatial_squeeze=spatial_squeeze, + reuse=reuse, + scope=scope) + + +resnet_v1_18.default_image_size = resnet_v1.default_image_size + + +def resnet_v1_50(inputs, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + spatial_squeeze=True, + reuse=None, + scope='resnet_v1_50'): + """ResNet-50 model of [1]. See resnet_v1() for arg and return description.""" + blocks = [ + resnet_utils.Block('block1', bottleneck, + [(256, 64, 1)] * 2 + [(256, 64, 2)]), + resnet_utils.Block('block2', bottleneck, + [(512, 128, 1)] * 3 + [(512, 128, 2)]), + resnet_utils.Block('block3', bottleneck, + [(1024, 256, 1)] * 5 + [(1024, 256, 2)]), + resnet_utils.Block('block4', bottleneck, + [(2048, 512, 1)] * 3 + [(2048, 512, 2)]), + resnet_utils.Block('block5', bottleneck, + [(1024, 256, 1)] * 2 + [(1024, 256, 2)]), + resnet_utils.Block('block6', bottleneck, [(1024, 256, 1)] * 2), + ] + return resnet_v1( + inputs, + blocks, + num_classes, + is_training, + global_pool=global_pool, + output_stride=output_stride, + include_root_block=True, + spatial_squeeze=spatial_squeeze, + reuse=reuse, + scope=scope) + + +resnet_v1_50.default_image_size = resnet_v1.default_image_size + + +def resnet_v1_101(inputs, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + spatial_squeeze=True, + reuse=None, + scope='resnet_v1_101'): + """ResNet-101 model of [1]. 
See resnet_v1() for arg and return description.""" + blocks = [ + resnet_utils.Block('block1', bottleneck, + [(256, 64, 1)] * 2 + [(256, 64, 2)]), + resnet_utils.Block('block2', bottleneck, + [(512, 128, 1)] * 3 + [(512, 128, 2)]), + resnet_utils.Block('block3', bottleneck, + [(1024, 256, 1)] * 22 + [(1024, 256, 2)]), + resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3) + ] + return resnet_v1( + inputs, + blocks, + num_classes, + is_training, + global_pool=global_pool, + output_stride=output_stride, + include_root_block=True, + spatial_squeeze=spatial_squeeze, + reuse=reuse, + scope=scope) + + +resnet_v1_101.default_image_size = resnet_v1.default_image_size + + +def resnet_v1_152(inputs, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + spatial_squeeze=True, + reuse=None, + scope='resnet_v1_152'): + """ResNet-152 model of [1]. See resnet_v1() for arg and return description.""" + blocks = [ + resnet_utils.Block('block1', bottleneck, + [(256, 64, 1)] * 2 + [(256, 64, 2)]), + resnet_utils.Block('block2', bottleneck, + [(512, 128, 1)] * 7 + [(512, 128, 2)]), + resnet_utils.Block('block3', bottleneck, + [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), + resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3) + ] + return resnet_v1( + inputs, + blocks, + num_classes, + is_training, + global_pool=global_pool, + output_stride=output_stride, + include_root_block=True, + spatial_squeeze=spatial_squeeze, + reuse=reuse, + scope=scope) + + +resnet_v1_152.default_image_size = resnet_v1.default_image_size + + +def resnet_v1_200(inputs, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + spatial_squeeze=True, + reuse=None, + scope='resnet_v1_200'): + """ResNet-200 model of [2]. See resnet_v1() for arg and return description.""" + blocks = [ + resnet_utils.Block('block1', bottleneck, + [(256, 64, 1)] * 2 + [(256, 64, 2)]), + resnet_utils.Block('block2', bottleneck, + [(512, 128, 1)] * 23 + [(512, 128, 2)]), + resnet_utils.Block('block3', bottleneck, + [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), + resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3) + ] + return resnet_v1( + inputs, + blocks, + num_classes, + is_training, + global_pool=global_pool, + output_stride=output_stride, + include_root_block=True, + spatial_squeeze=spatial_squeeze, + reuse=reuse, + scope=scope) + + +resnet_v1_200.default_image_size = resnet_v1.default_image_size + +if __name__ == '__main__': + input = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input') + with slim.arg_scope(resnet_arg_scope()) as sc: + logits = resnet_v1_50(input) diff --git a/modelscope/pipelines/cv/ocr_utils/resnet_utils.py b/modelscope/pipelines/cv/ocr_utils/resnet_utils.py new file mode 100644 index 00000000..e0e240c8 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/resnet_utils.py @@ -0,0 +1,231 @@ +"""Contains building blocks for various versions of Residual Networks. +Residual networks (ResNets) were proposed in: + Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 +More variants were introduced in: + Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 +We can obtain different ResNet variants by changing the network depth, width, +and form of residual unit. This module implements the infrastructure for +building them. 
Concrete ResNet units and full ResNet networks are implemented in +the accompanying resnet_v1.py and resnet_v2.py modules. +Compared to https://github.com/KaimingHe/deep-residual-networks, in the current +implementation we subsample the output activations in the last residual unit of +each block, instead of subsampling the input activations in the first residual +unit of each block. The two implementations give identical results but our +implementation is more memory efficient. +""" + +import collections + +import tensorflow as tf +import tf_slim as slim + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + + +class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): + """A named tuple describing a ResNet block. + Its parts are: + scope: The scope of the `Block`. + unit_fn: The ResNet unit function which takes as input a `Tensor` and + returns another `Tensor` with the output of the ResNet unit. + args: A list of length equal to the number of units in the `Block`. The list + contains one (depth, depth_bottleneck, stride) tuple for each unit in the + block to serve as argument to unit_fn. + """ + + +def subsample(inputs, factor, scope=None): + """Subsamples the input along the spatial dimensions. + Args: + inputs: A `Tensor` of size [batch, height_in, width_in, channels]. + factor: The subsampling factor. + scope: Optional variable_scope. + Returns: + output: A `Tensor` of size [batch, height_out, width_out, channels] with the + input, either intact (if factor == 1) or subsampled (if factor > 1). + """ + if factor == 1: + return inputs + else: + return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) + + +def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): + """Strided 2-D convolution with 'SAME' padding. + When stride > 1, then we do explicit zero-padding, followed by conv2d with + 'VALID' padding. + Note that + net = conv2d_same(inputs, num_outputs, 3, stride=stride) + is equivalent to + net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') + net = subsample(net, factor=stride) + whereas + net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') + is different when the input's height or width is even, which is why we add the + current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). + Args: + inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. + num_outputs: An integer, the number of output filters. + kernel_size: An int with the kernel_size of the filters. + stride: An integer, the output stride. + rate: An integer, rate for atrous convolution. + scope: Scope. + Returns: + output: A 4-D tensor of size [batch, height_out, width_out, channels] with + the convolution output. + """ + if stride == 1: + return slim.conv2d( + inputs, + num_outputs, + kernel_size, + stride=1, + rate=rate, + padding='SAME', + scope=scope) + else: + kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) + pad_total = kernel_size_effective - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + inputs = tf.pad( + inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) + return slim.conv2d( + inputs, + num_outputs, + kernel_size, + stride=stride, + rate=rate, + padding='VALID', + scope=scope) + + +@slim.add_arg_scope +def stack_blocks_dense(net, + blocks, + output_stride=None, + outputs_collections=None): + """Stacks ResNet `Blocks` and controls output feature density. 
+ First, this function creates scopes for the ResNet in the form of + 'block_name/unit_1', 'block_name/unit_2', etc. + Second, this function allows the user to explicitly control the ResNet + output_stride, which is the ratio of the input to output spatial resolution. + This is useful for dense prediction tasks such as semantic segmentation or + object detection. + Most ResNets consist of 4 ResNet blocks and subsample the activations by a + factor of 2 when transitioning between consecutive ResNet blocks. This results + to a nominal ResNet output_stride equal to 8. If we set the output_stride to + half the nominal network stride (e.g., output_stride=4), then we compute + responses twice. + Control of the output feature density is implemented by atrous convolution. + Args: + net: A `Tensor` of size [batch, height, width, channels]. + blocks: A list of length equal to the number of ResNet `Blocks`. Each + element is a ResNet `Block` object describing the units in the `Block`. + output_stride: If `None`, then the output will be computed at the nominal + network stride. If output_stride is not `None`, it specifies the requested + ratio of input to output spatial resolution, which needs to be equal to + the product of unit strides from the start up to some level of the ResNet. + For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, + then valid values for the output_stride are 1, 2, 6, 24 or None (which + is equivalent to output_stride=24). + outputs_collections: Collection to add the ResNet block outputs. + Returns: + net: Output tensor with stride equal to the specified output_stride. + Raises: + ValueError: If the target output_stride is not valid. + """ + # The current_stride variable keeps track of the effective stride of the + # activations. This allows us to invoke atrous convolution whenever applying + # the next residual unit would result in the activations having stride larger + # than the target output_stride. + current_stride = 1 + + # The atrous convolution rate parameter. + rate = 1 + + for block in blocks: + with tf.variable_scope(block.scope, 'block', [net]): + for i, unit in enumerate(block.args): + if output_stride is not None and current_stride > output_stride: + raise ValueError( + 'The target output_stride cannot be reached.') + + with tf.variable_scope( + 'unit_%d' % (i + 1), values=[net]) as sc: + unit_depth, unit_depth_bottleneck, unit_stride = unit + # If we have reached the target output_stride, then we need to employ + # atrous convolution with stride=1 and multiply the atrous rate by the + # current unit's stride for use in subsequent layers. + if output_stride is not None and current_stride == output_stride: + net = block.unit_fn( + net, + depth=unit_depth, + depth_bottleneck=unit_depth_bottleneck, + stride=1, + rate=rate) + rate *= unit_stride + + else: + net = block.unit_fn( + net, + depth=unit_depth, + depth_bottleneck=unit_depth_bottleneck, + stride=unit_stride, + rate=1) + current_stride *= unit_stride + net = slim.utils.collect_named_outputs( + outputs_collections, sc.name, net) + + if output_stride is not None and current_stride != output_stride: + raise ValueError('The target output_stride cannot be reached.') + + return net + + +def resnet_arg_scope(weight_decay=0.0001, + batch_norm_decay=0.997, + batch_norm_epsilon=1e-5, + batch_norm_scale=True): + """Defines the default ResNet arg scope. 
+ TODO(gpapan): The batch-normalization related default values above are + appropriate for use in conjunction with the reference ResNet models + released at https://github.com/KaimingHe/deep-residual-networks. When + training ResNets from scratch, they might need to be tuned. + Args: + weight_decay: The weight decay to use for regularizing the model. + batch_norm_decay: The moving average decay when estimating layer activation + statistics in batch normalization. + batch_norm_epsilon: Small constant to prevent division by zero when + normalizing activations by their variance in batch normalization. + batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the + activations in the batch normalization layer. + Returns: + An `arg_scope` to use for the resnet models. + """ + batch_norm_params = { + 'decay': batch_norm_decay, + 'epsilon': batch_norm_epsilon, + 'scale': batch_norm_scale, + 'updates_collections': tf.GraphKeys.UPDATE_OPS, + } + + with slim.arg_scope( + [slim.conv2d], + weights_regularizer=slim.l2_regularizer(weight_decay), + weights_initializer=slim.variance_scaling_initializer(), + activation_fn=tf.nn.relu, + normalizer_fn=slim.batch_norm, + normalizer_params=batch_norm_params): + with slim.arg_scope([slim.batch_norm], **batch_norm_params): + # The following implies padding='SAME' for pool1, which makes feature + # alignment easier for dense prediction tasks. This is also used in + # https://github.com/facebook/fb.resnet.torch. However the accompanying + # code of 'Deep Residual Learning for Image Recognition' uses + # padding='VALID' for pool1. You can switch to that choice by setting + # slim.arg_scope([slim.max_pool2d], padding='VALID'). + with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: + return arg_sc diff --git a/modelscope/pipelines/cv/ocr_utils/utils.py b/modelscope/pipelines/cv/ocr_utils/utils.py new file mode 100644 index 00000000..be8e3371 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/utils.py @@ -0,0 +1,108 @@ +import cv2 +import numpy as np + + +def rboxes_to_polygons(rboxes): + """ + Convert rboxes to polygons + ARGS + `rboxes`: [n, 5] + RETURN + `polygons`: [n, 8] + """ + + theta = rboxes[:, 4:5] + cxcy = rboxes[:, :2] + half_w = rboxes[:, 2:3] / 2. + half_h = rboxes[:, 3:4] / 2. 
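+    # v1/v2 are the half-width and half-height edge vectors rotated by theta;
+    # adding/subtracting them from the center (cxcy) below yields the four
+    # corners p1..p4 in order, so each output row is [x1, y1, ..., x4, y4].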
+ v1 = np.hstack([np.cos(theta) * half_w, np.sin(theta) * half_w]) + v2 = np.hstack([-np.sin(theta) * half_h, np.cos(theta) * half_h]) + p1 = cxcy - v1 - v2 + p2 = cxcy + v1 - v2 + p3 = cxcy + v1 + v2 + p4 = cxcy - v1 + v2 + polygons = np.hstack([p1, p2, p3, p4]) + return polygons + + +def cal_width(box): + pd1 = point_dist(box[0], box[1], box[2], box[3]) + pd2 = point_dist(box[4], box[5], box[6], box[7]) + return (pd1 + pd2) / 2 + + +def point_dist(x1, y1, x2, y2): + return np.sqrt((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)) + + +def draw_polygons(img, polygons): + for p in polygons.tolist(): + p = [int(o) for o in p] + cv2.line(img, (p[0], p[1]), (p[2], p[3]), (0, 255, 0), 1) + cv2.line(img, (p[2], p[3]), (p[4], p[5]), (0, 255, 0), 1) + cv2.line(img, (p[4], p[5]), (p[6], p[7]), (0, 255, 0), 1) + cv2.line(img, (p[6], p[7]), (p[0], p[1]), (0, 255, 0), 1) + return img + + +def nms_python(boxes): + boxes = sorted(boxes, key=lambda x: -x[8]) + nms_flag = [True] * len(boxes) + for i, a in enumerate(boxes): + if not nms_flag[i]: + continue + else: + for j, b in enumerate(boxes): + if not j > i: + continue + if not nms_flag[j]: + continue + score_a = a[8] + score_b = b[8] + rbox_a = polygon2rbox(a[:8]) + rbox_b = polygon2rbox(b[:8]) + if point_in_rbox(rbox_a[:2], rbox_b) or point_in_rbox( + rbox_b[:2], rbox_a): + if score_a > score_b: + nms_flag[j] = False + boxes_nms = [] + for i, box in enumerate(boxes): + if nms_flag[i]: + boxes_nms.append(box) + return boxes_nms + + +def point_in_rbox(c, rbox): + cx0, cy0 = c[0], c[1] + cx1, cy1 = rbox[0], rbox[1] + w, h = rbox[2], rbox[3] + theta = rbox[4] + dist_x = np.abs((cx1 - cx0) * np.cos(theta) + (cy1 - cy0) * np.sin(theta)) + dist_y = np.abs(-(cx1 - cx0) * np.sin(theta) + (cy1 - cy0) * np.cos(theta)) + return ((dist_x < w / 2.0) and (dist_y < h / 2.0)) + + +def polygon2rbox(polygon): + x1, x2, x3, x4 = polygon[0], polygon[2], polygon[4], polygon[6] + y1, y2, y3, y4 = polygon[1], polygon[3], polygon[5], polygon[7] + c_x = (x1 + x2 + x3 + x4) / 4 + c_y = (y1 + y2 + y3 + y4) / 4 + w1 = point_dist(x1, y1, x2, y2) + w2 = point_dist(x3, y3, x4, y4) + h1 = point_line_dist(c_x, c_y, x1, y1, x2, y2) + h2 = point_line_dist(c_x, c_y, x3, y3, x4, y4) + h = h1 + h2 + w = (w1 + w2) / 2 + theta1 = np.arctan2(y2 - y1, x2 - x1) + theta2 = np.arctan2(y3 - y4, x3 - x4) + theta = (theta1 + theta2) / 2.0 + return [c_x, c_y, w, h, theta] + + +def point_line_dist(px, py, x1, y1, x2, y2): + eps = 1e-6 + dx = x2 - x1 + dy = y2 - y1 + div = np.sqrt(dx * dx + dy * dy) + eps + dist = np.abs(px * dy - py * dx + x2 * y1 - y2 * x1) / div + return dist diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py index 15d8a995..d7bdfd29 100644 --- a/modelscope/pipelines/outputs.py +++ b/modelscope/pipelines/outputs.py @@ -54,6 +54,13 @@ TASK_OUTPUTS = { # } Tasks.pose_estimation: ['poses', 'boxes'], + # ocr detection result for single sample + # { + # "det_polygons": np.array with shape [num_text, 8], each box is + # [x1, y1, x2, y2, x3, y3, x4, y4] + # } + Tasks.ocr_detection: ['det_polygons'], + # ============ nlp tasks =================== # text classification result for single sample diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 61049734..c26a9e24 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -28,6 +28,7 @@ class Tasks(object): image_editing = 'image-editing' image_generation = 'image-generation' image_matting = 'image-matting' + ocr_detection = 'ocr-detection' # nlp tasks word_segmentation 
= 'word-segmentation' diff --git a/requirements/cv.txt b/requirements/cv.txt index 66799b76..5bec8ba7 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -1 +1,2 @@ easydict +tf_slim diff --git a/tests/pipelines/test_ocr_detection.py b/tests/pipelines/test_ocr_detection.py new file mode 100644 index 00000000..62fcedd3 --- /dev/null +++ b/tests/pipelines/test_ocr_detection.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import shutil +import sys +import tempfile +import unittest +from typing import Any, Dict, List, Tuple, Union + +import cv2 +import numpy as np +import PIL + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class OCRDetectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_resnet18_ocr-detection-line-level_damo' + self.test_image = 'data/test/images/ocr_detection.jpg' + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + print('ocr detection results: ') + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_modelhub_default_model(self): + ocr_detection = pipeline(Tasks.ocr_detection) + self.pipeline_inference(ocr_detection, self.test_image) + + +if __name__ == '__main__': + unittest.main()
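Below is a minimal usage sketch for the new pipeline, mirroring the test above. It assumes the default model resolution added in builder.py (damo/cv_resnet18_ocr-detection-line-level_damo), the bundled test image, and the draw_polygons helper from the ocr_utils module introduced in this change; it is an illustration, not part of the diff.

import cv2

from modelscope.pipelines import pipeline
from modelscope.pipelines.cv.ocr_utils.utils import draw_polygons
from modelscope.utils.constant import Tasks

# Build the pipeline; with no explicit model, the default line-level
# ResNet-18 detection model registered in builder.py is used.
ocr_detection = pipeline(Tasks.ocr_detection)

# Run detection on the sample image shipped with this change.
result = ocr_detection('data/test/images/ocr_detection.jpg')
polygons = result['det_polygons']  # [num_text, 8], corner coordinates in the original image

# Visualize: draw each detected quadrilateral on the original image.
img = cv2.imread('data/test/images/ocr_detection.jpg')
vis = draw_polygons(img, polygons)
cv2.imwrite('ocr_detection_vis.jpg', vis)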