diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py
index b73f65a4..07231efa 100644
--- a/modelscope/pipelines/cv/ocr_detection_pipeline.py
+++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py
@@ -56,68 +56,72 @@ class OCRDetectionPipeline(Pipeline):
         model_path = osp.join(
             osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
             'checkpoint-80000')
-
-        with device_placement(self.framework, self.device_name):
-            config = tf.ConfigProto(allow_soft_placement=True)
-            config.gpu_options.allow_growth = True
-            self._session = tf.Session(config=config)
-            self.input_images = tf.placeholder(
-                tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
-            self.output = {}
-
-            with tf.variable_scope('', reuse=tf.AUTO_REUSE):
-                global_step = tf.get_variable(
-                    'global_step', [],
-                    initializer=tf.constant_initializer(0),
-                    dtype=tf.int64,
-                    trainable=False)
-                variable_averages = tf.train.ExponentialMovingAverage(
-                    0.997, global_step)
-
-                # detector
-                detector = SegLinkDetector()
-                all_maps = detector.build_model(
-                    self.input_images, is_training=False)
-
-                # decode local predictions
-                all_nodes, all_links, all_reg = [], [], []
-                for i, maps in enumerate(all_maps):
-                    cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
-                    reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
-
-                    cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
-
-                    lnk_prob_pos = tf.nn.softmax(
-                        tf.reshape(lnk_maps, [-1, 4])[:, :2])
-                    lnk_prob_mut = tf.nn.softmax(
-                        tf.reshape(lnk_maps, [-1, 4])[:, 2:])
-                    lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)
-
-                    all_nodes.append(cls_prob)
-                    all_links.append(lnk_prob)
-                    all_reg.append(reg_maps)
-
-                # decode segments and links
-                image_size = tf.shape(self.input_images)[1:3]
-                segments, group_indices, segment_counts, _ = decode_segments_links_python(
-                    image_size,
-                    all_nodes,
-                    all_links,
-                    all_reg,
-                    anchor_sizes=list(detector.anchor_sizes))
-
-                # combine segments
-                combined_rboxes, combined_counts = combine_segments_python(
-                    segments, group_indices, segment_counts)
-                self.output['combined_rboxes'] = combined_rboxes
-                self.output['combined_counts'] = combined_counts
-
-            with self._session.as_default() as sess:
-                logger.info(f'loading model from {model_path}')
-                # load model
-                model_loader = tf.train.Saver(
-                    variable_averages.variables_to_restore())
-                model_loader.restore(sess, model_path)
+        self._graph = tf.get_default_graph()
+        config = tf.ConfigProto(allow_soft_placement=True)
+        config.gpu_options.allow_growth = True
+        self._session = tf.Session(config=config)
+
+        with self._graph.as_default():
+            with device_placement(self.framework, self.device_name):
+                self.input_images = tf.placeholder(
+                    tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
+                self.output = {}
+
+                with tf.variable_scope('', reuse=tf.AUTO_REUSE):
+                    global_step = tf.get_variable(
+                        'global_step', [],
+                        initializer=tf.constant_initializer(0),
+                        dtype=tf.int64,
+                        trainable=False)
+                    variable_averages = tf.train.ExponentialMovingAverage(
+                        0.997, global_step)
+
+                    # detector
+                    detector = SegLinkDetector()
+                    all_maps = detector.build_model(
+                        self.input_images, is_training=False)
+
+                    # decode local predictions
+                    all_nodes, all_links, all_reg = [], [], []
+                    for i, maps in enumerate(all_maps):
+                        cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[
+                            2]
+                        reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
+
+                        cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
+
+                        lnk_prob_pos = tf.nn.softmax(
+                            tf.reshape(lnk_maps, [-1, 4])[:, :2])
+                        lnk_prob_mut = tf.nn.softmax(
+                            tf.reshape(lnk_maps, [-1, 4])[:, 2:])
+                        lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut],
+                                             axis=1)
+
+                        all_nodes.append(cls_prob)
+                        all_links.append(lnk_prob)
+                        all_reg.append(reg_maps)
+
+                    # decode segments and links
+                    image_size = tf.shape(self.input_images)[1:3]
+                    segments, group_indices, segment_counts, _ = decode_segments_links_python(
+                        image_size,
+                        all_nodes,
+                        all_links,
+                        all_reg,
+                        anchor_sizes=list(detector.anchor_sizes))
+
+                    # combine segments
+                    combined_rboxes, combined_counts = combine_segments_python(
+                        segments, group_indices, segment_counts)
+                    self.output['combined_rboxes'] = combined_rboxes
+                    self.output['combined_counts'] = combined_counts
+
+                with self._session.as_default() as sess:
+                    logger.info(f'loading model from {model_path}')
+                    # load model
+                    model_loader = tf.train.Saver(
+                        variable_averages.variables_to_restore())
+                    model_loader.restore(sess, model_path)
 
     def preprocess(self, input: Input) -> Dict[str, Any]:
         img = LoadImage.convert_to_ndarray(input)
@@ -132,19 +136,22 @@ class OCRDetectionPipeline(Pipeline):
         img_pad_resize = img_pad_resize - np.array([123.68, 116.78, 103.94],
                                                    dtype=np.float32)
 
-        resize_size = tf.stack([resize_size, resize_size])
-        orig_size = tf.stack([max(h, w), max(h, w)])
-        self.output['orig_size'] = orig_size
-        self.output['resize_size'] = resize_size
+        with self._graph.as_default():
+            resize_size = tf.stack([resize_size, resize_size])
+            orig_size = tf.stack([max(h, w), max(h, w)])
+            self.output['orig_size'] = orig_size
+            self.output['resize_size'] = resize_size
 
         result = {'img': np.expand_dims(img_pad_resize, axis=0)}
         return result
 
     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
-        with self._session.as_default():
-            feed_dict = {self.input_images: input['img']}
-            sess_outputs = self._session.run(self.output, feed_dict=feed_dict)
-            return sess_outputs
+        with self._graph.as_default():
+            with self._session.as_default():
+                feed_dict = {self.input_images: input['img']}
+                sess_outputs = self._session.run(
+                    self.output, feed_dict=feed_dict)
+                return sess_outputs
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         rboxes = inputs['combined_rboxes'][0]
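
Why the change works (a minimal standalone sketch, not part of the diff): in TF1 graph mode, tensors and a tf.Session are bound to a single tf.Graph, but preprocess() and forward() can be invoked while a different graph is the ambient default (e.g. when several pipelines live in one process). The diff pins the graph once in __init__ and re-enters it with self._graph.as_default() before building ops or calling session.run, so new ops like tf.stack land in the session's graph instead of whichever graph happens to be current. The Doubler class and its toy placeholder below are hypothetical stand-ins; only the graph/session pattern mirrors the diff.

import tensorflow as tf  # assumes a TF1 (graph-mode) install, as the diff does

class Doubler:

    def __init__(self):
        # Pin the graph that is the default at construction time; the
        # session created here is bound to this same graph.
        self._graph = tf.get_default_graph()
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self._session = tf.Session(config=config)
        with self._graph.as_default():
            # Stand-in for the real detector graph built in __init__.
            self._x = tf.placeholder(tf.float32, shape=[None], name='x')
            self._y = tf.multiply(self._x, 2.0)

    def forward(self, values):
        # Re-enter the pinned graph before running: session.run only
        # accepts tensors that belong to the session's own graph, so this
        # stays correct even if the caller changed the default graph.
        with self._graph.as_default():
            with self._session.as_default():
                return self._session.run(
                    self._y, feed_dict={self._x: values})

d = Doubler()
print(d.forward([1.0, 2.0, 3.0]))  # [2. 4. 6.]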