From d4692b5ada73936b119daf440f9409f06524c04d Mon Sep 17 00:00:00 2001
From: "xixing.tj"
Date: Tue, 28 Jun 2022 14:03:01 +0800
Subject: [PATCH] [to #42322933]Merge branch 'master' into ocr/ocr_detection
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fix the ocr_detection unit test bug on the master branch

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9112290

* create ocr_detection task
* fix code check error
* fix code check error
* fix code check issue
* fix code check issue
* replace c++ nms with python version
* fix code check issue
* fix code check issue
* rename maas_lib
* merge master to ocr/ocr_detection
* add model_hub sup for ocr_detection
* fix bug
* replace c++ decoder with python version
* fix bug
* Merge branch 'master' into ocr/ocr_detection
* merge master
* fix code check
* update
* add requirements for ocr_detection
* fix model_hub fetch bug
* remove debug code
* Merge branch 'master' into ocr/ocr_detection
* add local test image for ocr_detection
* update requirements for model_hub
* Merge branch 'master' into ocr/ocr_detection
* fix bug for full case test
* remove ema for ocr_detection
* Merge branch 'master' into ocr/ocr_detection
* apply ocr_detection test case
* Merge branch 'master' into ocr/ocr_detection
* update slim dependency for ocr_detection
* add more test case for ocr_detection
* release tf graph before create
* recover ema for ocr_detection model
* fix code
* Merge branch 'master' into ocr/ocr_detection
* fix code
---
 .../pipelines/cv/ocr_detection_pipeline.py | 94 ++++++++++---------
 .../model_resnet_mutex_v4_linewithchar.py  |  6 +-
 .../pipelines/cv/ocr_utils/resnet18_v1.py  |  6 +-
 .../pipelines/cv/ocr_utils/resnet_utils.py |  6 +-
 tests/pipelines/test_ocr_detection.py      |  5 +
 5 files changed, 72 insertions(+), 45 deletions(-)

diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py
index 0502fe36..4856b06b 100644
--- a/modelscope/pipelines/cv/ocr_detection_pipeline.py
+++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py
@@ -8,7 +8,6 @@ import cv2
 import numpy as np
 import PIL
 import tensorflow as tf
-import tf_slim as slim
 
 from modelscope.metainfo import Pipelines
 from modelscope.pipelines.base import Input
@@ -19,6 +18,11 @@ from ..base import Pipeline
 from ..builder import PIPELINES
 from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils
 
+if tf.__version__ >= '2.0':
+    import tf_slim as slim
+else:
+    from tensorflow.contrib import slim
+
 if tf.__version__ >= '2.0':
     tf = tf.compat.v1
     tf.compat.v1.disable_eager_execution()
@@ -44,6 +48,7 @@ class OCRDetectionPipeline(Pipeline):
 
     def __init__(self, model: str):
         super().__init__(model=model)
+        tf.reset_default_graph()
         model_path = osp.join(
             osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
             'checkpoint-80000')
@@ -51,51 +56,56 @@ class OCRDetectionPipeline(Pipeline):
         config = tf.ConfigProto(allow_soft_placement=True)
         config.gpu_options.allow_growth = True
         self._session = tf.Session(config=config)
-        global_step = tf.get_variable(
-            'global_step', [],
-            initializer=tf.constant_initializer(0),
-            dtype=tf.int64,
-            trainable=False)
-        variable_averages = tf.train.ExponentialMovingAverage(
-            0.997, global_step)
         self.input_images = tf.placeholder(
             tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
         self.output = {}
 
-        # detector
-        detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector()
-        all_maps = detector.build_model(self.input_images, is_training=False)
-
-        # decode local predictions
-        all_nodes, all_links, all_reg = [], [], []
-        for i, maps in enumerate(all_maps):
-            cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
-            reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
-
-            cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
-
-            lnk_prob_pos = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, :2])
-            lnk_prob_mut = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, 2:])
-            lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)
-
-            all_nodes.append(cls_prob)
-            all_links.append(lnk_prob)
-            all_reg.append(reg_maps)
-
-        # decode segments and links
-        image_size = tf.shape(self.input_images)[1:3]
-        segments, group_indices, segment_counts, _ = ops.decode_segments_links_python(
-            image_size,
-            all_nodes,
-            all_links,
-            all_reg,
-            anchor_sizes=list(detector.anchor_sizes))
-
-        # combine segments
-        combined_rboxes, combined_counts = ops.combine_segments_python(
-            segments, group_indices, segment_counts)
-        self.output['combined_rboxes'] = combined_rboxes
-        self.output['combined_counts'] = combined_counts
+        with tf.variable_scope('', reuse=tf.AUTO_REUSE):
+            global_step = tf.get_variable(
+                'global_step', [],
+                initializer=tf.constant_initializer(0),
+                dtype=tf.int64,
+                trainable=False)
+            variable_averages = tf.train.ExponentialMovingAverage(
+                0.997, global_step)
+
+            # detector
+            detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector()
+            all_maps = detector.build_model(
+                self.input_images, is_training=False)
+
+            # decode local predictions
+            all_nodes, all_links, all_reg = [], [], []
+            for i, maps in enumerate(all_maps):
+                cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
+                reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
+
+                cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
+
+                lnk_prob_pos = tf.nn.softmax(
+                    tf.reshape(lnk_maps, [-1, 4])[:, :2])
+                lnk_prob_mut = tf.nn.softmax(
+                    tf.reshape(lnk_maps, [-1, 4])[:, 2:])
+                lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)
+
+                all_nodes.append(cls_prob)
+                all_links.append(lnk_prob)
+                all_reg.append(reg_maps)
+
+            # decode segments and links
+            image_size = tf.shape(self.input_images)[1:3]
+            segments, group_indices, segment_counts, _ = ops.decode_segments_links_python(
+                image_size,
+                all_nodes,
+                all_links,
+                all_reg,
+                anchor_sizes=list(detector.anchor_sizes))
+
+            # combine segments
+            combined_rboxes, combined_counts = ops.combine_segments_python(
+                segments, group_indices, segment_counts)
+            self.output['combined_rboxes'] = combined_rboxes
+            self.output['combined_counts'] = combined_counts
 
         with self._session.as_default() as sess:
             logger.info(f'loading model from {model_path}')
diff --git a/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py b/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py
index 50b8ba02..d03ff405 100644
--- a/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py
+++ b/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py
@@ -1,8 +1,12 @@
 import tensorflow as tf
-import tf_slim as slim
 
 from . import ops, resnet18_v1, resnet_utils
 
+if tf.__version__ >= '2.0':
+    import tf_slim as slim
+else:
+    from tensorflow.contrib import slim
+
 if tf.__version__ >= '2.0':
     tf = tf.compat.v1
diff --git a/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py b/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py
index 6371d4e5..7930c5a3 100644
--- a/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py
+++ b/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py
@@ -30,10 +30,14 @@ ResNet-101 for semantic segmentation into 21 classes:
    output_stride=16)
 """
 import tensorflow as tf
-import tf_slim as slim
 
 from . import resnet_utils
 
+if tf.__version__ >= '2.0':
+    import tf_slim as slim
+else:
+    from tensorflow.contrib import slim
+
 if tf.__version__ >= '2.0':
     tf = tf.compat.v1
diff --git a/modelscope/pipelines/cv/ocr_utils/resnet_utils.py b/modelscope/pipelines/cv/ocr_utils/resnet_utils.py
index e0e240c8..0a9af224 100644
--- a/modelscope/pipelines/cv/ocr_utils/resnet_utils.py
+++ b/modelscope/pipelines/cv/ocr_utils/resnet_utils.py
@@ -19,7 +19,11 @@ implementation is more memory efficient.
 import collections
 
 import tensorflow as tf
-import tf_slim as slim
+
+if tf.__version__ >= '2.0':
+    import tf_slim as slim
+else:
+    from tensorflow.contrib import slim
 
 if tf.__version__ >= '2.0':
     tf = tf.compat.v1
diff --git a/tests/pipelines/test_ocr_detection.py b/tests/pipelines/test_ocr_detection.py
index 986961b7..d1ecd4e4 100644
--- a/tests/pipelines/test_ocr_detection.py
+++ b/tests/pipelines/test_ocr_detection.py
@@ -27,6 +27,11 @@ class OCRDetectionTest(unittest.TestCase):
         print('ocr detection results: ')
         print(result)
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        ocr_detection = pipeline(Tasks.ocr_detection, model=self.model_id)
+        self.pipeline_inference(ocr_detection, self.test_image)
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_modelhub_default_model(self):
         ocr_detection = pipeline(Tasks.ocr_detection)
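
Reviewer sketch (not part of the patch): the change repeated across the four modified modules is a version-conditional import of slim (TF 2.x ships it as the standalone tf_slim package, TF 1.x keeps it under tensorflow.contrib), combined with resetting the default graph and building variables under tf.variable_scope('', reuse=tf.AUTO_REUSE) so that the pipeline can be constructed more than once in the same test process without "variable already exists" errors. A minimal standalone illustration of that pattern follows; build_graph and the __main__ driver are illustrative names, not code from this repository.

import tensorflow as tf

# TF 2.x dropped tensorflow.contrib; slim now lives in the separate tf_slim
# package, so pick the import based on the installed version (the same check
# the patch adds to each ocr_utils module).
if tf.__version__ >= '2.0':
    import tf_slim as slim  # noqa: F401  (imported only to show the shim)
else:
    from tensorflow.contrib import slim  # noqa: F401

# Run the TF1-style graph/session API on TF 2.x as well.
if tf.__version__ >= '2.0':
    tf = tf.compat.v1
    tf.compat.v1.disable_eager_execution()


def build_graph():
    """Illustrative stand-in for the graph setup done in a pipeline __init__."""
    # Discard any graph left behind by a previous construction (for example an
    # earlier unit test in the same process) ...
    tf.reset_default_graph()
    # ... and allow tf.get_variable to reuse an existing name in this scope,
    # so repeated construction does not raise.
    with tf.variable_scope('', reuse=tf.AUTO_REUSE):
        global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0),
            dtype=tf.int64,
            trainable=False)
    return global_step


if __name__ == '__main__':
    # Building twice in one process is the situation the unit-test fix targets.
    build_graph()
    build_graph()
    print('graph was rebuilt twice without variable-reuse errors')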