[to #42322933] fix bug in demo_service: "Tensor is not an element of this graph"

Fix the bug in the ocr_detection demo_service.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10333119
master · xixing.tj, yingda.chen · 3 years ago · parent commit 32002a290f
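
For context, here is a minimal, self-contained sketch of the failure mode this commit fixes, assuming a TF1.x environment (the pipeline uses the TF1 graph/session API). The tensor names and sizes are illustrative, not taken from the pipeline:

import tensorflow as tf  # TF1.x API, matching the patched file

# Build the model inside its own graph, as a service that loads
# several models typically does.
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[], name='x')
    y = x * 2.0
session = tf.Session(graph=graph)

# A per-request op built from plain Python values (compare
# tf.stack([resize_size, resize_size]) in preprocess below) has no
# tensor inputs to infer a graph from, so it lands in the current
# default graph, which is not `graph`:
orig_size = tf.stack([1024, 1024])

try:
    session.run([y, orig_size], feed_dict={x: 3.0})
except (TypeError, ValueError) as err:
    print(err)  # ... Tensor ... is not an element of this graph.

# The fix pattern used by this commit: pin graph-touching calls to the
# captured graph before running the session.
with graph.as_default():
    orig_size = tf.stack([1024, 1024])
print(session.run([y, orig_size], feed_dict={x: 3.0}))

Capturing tf.get_default_graph() right before the session is created works because the placeholder and all output tensors are then built inside that same graph, so preprocess and forward can re-enter it no matter which graph happens to be the default at call time.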
1 changed file with 77 additions and 70 deletions

modelscope/pipelines/cv/ocr_detection_pipeline.py (+77, -70)

@@ -56,68 +56,72 @@ class OCRDetectionPipeline(Pipeline):
         model_path = osp.join(
             osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
             'checkpoint-80000')
 
-        with device_placement(self.framework, self.device_name):
-            config = tf.ConfigProto(allow_soft_placement=True)
-            config.gpu_options.allow_growth = True
-            self._session = tf.Session(config=config)
-            self.input_images = tf.placeholder(
-                tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
-            self.output = {}
-
-            with tf.variable_scope('', reuse=tf.AUTO_REUSE):
-                global_step = tf.get_variable(
-                    'global_step', [],
-                    initializer=tf.constant_initializer(0),
-                    dtype=tf.int64,
-                    trainable=False)
-                variable_averages = tf.train.ExponentialMovingAverage(
-                    0.997, global_step)
-
-                # detector
-                detector = SegLinkDetector()
-                all_maps = detector.build_model(
-                    self.input_images, is_training=False)
-
-                # decode local predictions
-                all_nodes, all_links, all_reg = [], [], []
-                for i, maps in enumerate(all_maps):
-                    cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
-                    reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
-
-                    cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
-
-                    lnk_prob_pos = tf.nn.softmax(
-                        tf.reshape(lnk_maps, [-1, 4])[:, :2])
-                    lnk_prob_mut = tf.nn.softmax(
-                        tf.reshape(lnk_maps, [-1, 4])[:, 2:])
-                    lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)
-
-                    all_nodes.append(cls_prob)
-                    all_links.append(lnk_prob)
-                    all_reg.append(reg_maps)
-
-                # decode segments and links
-                image_size = tf.shape(self.input_images)[1:3]
-                segments, group_indices, segment_counts, _ = decode_segments_links_python(
-                    image_size,
-                    all_nodes,
-                    all_links,
-                    all_reg,
-                    anchor_sizes=list(detector.anchor_sizes))
-
-                # combine segments
-                combined_rboxes, combined_counts = combine_segments_python(
-                    segments, group_indices, segment_counts)
-                self.output['combined_rboxes'] = combined_rboxes
-                self.output['combined_counts'] = combined_counts
-
-            with self._session.as_default() as sess:
-                logger.info(f'loading model from {model_path}')
-                # load model
-                model_loader = tf.train.Saver(
-                    variable_averages.variables_to_restore())
-                model_loader.restore(sess, model_path)
+        self._graph = tf.get_default_graph()
+        config = tf.ConfigProto(allow_soft_placement=True)
+        config.gpu_options.allow_growth = True
+        self._session = tf.Session(config=config)
+
+        with self._graph.as_default():
+            with device_placement(self.framework, self.device_name):
+                self.input_images = tf.placeholder(
+                    tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
+                self.output = {}
+
+                with tf.variable_scope('', reuse=tf.AUTO_REUSE):
+                    global_step = tf.get_variable(
+                        'global_step', [],
+                        initializer=tf.constant_initializer(0),
+                        dtype=tf.int64,
+                        trainable=False)
+                    variable_averages = tf.train.ExponentialMovingAverage(
+                        0.997, global_step)
+
+                    # detector
+                    detector = SegLinkDetector()
+                    all_maps = detector.build_model(
+                        self.input_images, is_training=False)
+
+                    # decode local predictions
+                    all_nodes, all_links, all_reg = [], [], []
+                    for i, maps in enumerate(all_maps):
+                        cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[
+                            2]
+                        reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
+
+                        cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
+
+                        lnk_prob_pos = tf.nn.softmax(
+                            tf.reshape(lnk_maps, [-1, 4])[:, :2])
+                        lnk_prob_mut = tf.nn.softmax(
+                            tf.reshape(lnk_maps, [-1, 4])[:, 2:])
+                        lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut],
+                                             axis=1)
+
+                        all_nodes.append(cls_prob)
+                        all_links.append(lnk_prob)
+                        all_reg.append(reg_maps)
+
+                    # decode segments and links
+                    image_size = tf.shape(self.input_images)[1:3]
+                    segments, group_indices, segment_counts, _ = decode_segments_links_python(
+                        image_size,
+                        all_nodes,
+                        all_links,
+                        all_reg,
+                        anchor_sizes=list(detector.anchor_sizes))
+
+                    # combine segments
+                    combined_rboxes, combined_counts = combine_segments_python(
+                        segments, group_indices, segment_counts)
+                    self.output['combined_rboxes'] = combined_rboxes
+                    self.output['combined_counts'] = combined_counts
+
+                with self._session.as_default() as sess:
+                    logger.info(f'loading model from {model_path}')
+                    # load model
+                    model_loader = tf.train.Saver(
+                        variable_averages.variables_to_restore())
+                    model_loader.restore(sess, model_path)
 
     def preprocess(self, input: Input) -> Dict[str, Any]:
         img = LoadImage.convert_to_ndarray(input)
@@ -132,19 +136,22 @@ class OCRDetectionPipeline(Pipeline):
         img_pad_resize = img_pad_resize - np.array([123.68, 116.78, 103.94],
                                                    dtype=np.float32)
 
-        resize_size = tf.stack([resize_size, resize_size])
-        orig_size = tf.stack([max(h, w), max(h, w)])
-        self.output['orig_size'] = orig_size
-        self.output['resize_size'] = resize_size
+        with self._graph.as_default():
+            resize_size = tf.stack([resize_size, resize_size])
+            orig_size = tf.stack([max(h, w), max(h, w)])
+            self.output['orig_size'] = orig_size
+            self.output['resize_size'] = resize_size
 
         result = {'img': np.expand_dims(img_pad_resize, axis=0)}
         return result
 
     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
-        with self._session.as_default():
-            feed_dict = {self.input_images: input['img']}
-            sess_outputs = self._session.run(self.output, feed_dict=feed_dict)
-            return sess_outputs
+        with self._graph.as_default():
+            with self._session.as_default():
+                feed_dict = {self.input_images: input['img']}
+                sess_outputs = self._session.run(
+                    self.output, feed_dict=feed_dict)
+                return sess_outputs
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         rboxes = inputs['combined_rboxes'][0]
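
Why this bites a demo service in particular: TF1's default-graph stack is thread-local, so a worker thread handling a request does not inherit the as_default() scope that was active while the model was loaded. Below is a hedged usage sketch of the patched pattern; handle_request is a hypothetical stand-in for the demo_service request handler, not code from this repository:

import threading

import tensorflow as tf  # TF1.x API, as above

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[], name='x')
    y = x * 2.0
session = tf.Session(graph=graph)

def handle_request(value):
    # On this worker thread the default graph is NOT `graph`, so
    # re-enter it explicitly, exactly as the patched preprocess and
    # forward now do.
    with graph.as_default():
        size = tf.stack([1024, 1024])  # per-request op, pinned to `graph`
        with session.as_default():
            print(session.run([y, size], feed_dict={x: value}))

worker = threading.Thread(target=handle_request, args=(3.0,))
worker.start()
worker.join()  # prints [6.0, array([1024, 1024], dtype=int32)]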

