Browse Source

fix con

pull/451/head
yu-jiaoliang 4 years ago
parent
commit
a046d31ad8
1 changed files with 24 additions and 18 deletions
  1. +24
    -18
      parser/func_to_graph/func2graph.py

+ 24
- 18
parser/func_to_graph/func2graph.py View File

@@ -5,13 +5,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved.
#-------------------------------------------------------------------

"""Class to convert function to graph"""

import os
import sys
import getopt

from google.protobuf import text_format
import graph_library_pb2
import tensorflow as tf
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework.errors_impl import NotFoundError
from tensorflow.python.platform import gfile

@@ -20,34 +21,30 @@ from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions

sys.path.append(os.path.join(os.path.split(os.path.realpath(__file__))[0], "util"))

import graph_library_pb2


def _get_num_args(arg_def, node_def):
if arg_def.number_attr:
return node_def.attr[arg_def.number_attr].i
elif arg_def.type_list_attr:
if arg_def.type_list_attr:
return len(node_def.attr[arg_def.type_list_attr].list.type)
elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
if arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
return 1
else:
raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))


def is_function(fname):
"""Checks for a function definition with `fname` in the current context."""
if context.executing_eagerly():
return context.context().has_function(fname)
else:
return ops.get_default_graph()._is_function(fname)
return ops.get_default_graph()._is_function(fname)

def create_arg_for_input_nodes(fdef, graph_def, input_shapes):
"""Create arg for input nodes."""
for i, arg_def in enumerate(fdef.signature.input_arg):
node_def = graph_def.node.add()
node_def.name = arg_def.name
@@ -65,9 +62,10 @@ def create_arg_for_input_nodes(fdef, graph_def, input_shapes):
# applied to these Arg nodes.
if k.startswith("_"):
node_def.attr[k].CopyFrom(arg_attrs[k])
return

def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name):
"""Create retval for output nodes."""
for i, arg_def in enumerate(fdef.signature.output_arg):
node_def = graph_def.node.add()
node_def.name = '{}_Retval'.format(arg_def.name)
@@ -78,9 +76,10 @@ def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name):

ret_name = fdef.ret[arg_def.name]
node_def.input.append(nested_to_flat_tensor_name[ret_name])
return

def updat_input_index(node_def, op_def, nested_to_flat_tensor_name):
"""Update input index."""
flattened_index = 0
for arg_def in op_def.output_arg:
num_args = _get_num_args(arg_def, node_def)
@@ -98,14 +97,16 @@ def updat_input_index(node_def, op_def, nested_to_flat_tensor_name):
nested_to_flat_tensor_name[control_name] = control_name
return


def build_tensor_name(fdef, default_graph):
"""Build tensor name."""
nested_to_flat_tensor_name = {}
for arg_def in fdef.signature.input_arg:
nested_to_flat_tensor_name[arg_def.name] = arg_def.name
control_name = '^{}'.format(arg_def.name)
nested_to_flat_tensor_name[control_name] = control_name

global op_def
op_def = None
for node_def in fdef.node_def:
f = default_graph._functions.get(node_def.op, None)
if f is not None and hasattr(f, "signature"):
@@ -136,7 +137,9 @@ def build_tensor_name(fdef, default_graph):
updat_input_index(node_def, op_def, nested_to_flat_tensor_name)
return nested_to_flat_tensor_name


def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
"""Convert function def to graph def"""
graph_def = graph_pb2.GraphDef()
graph_def.versions.CopyFrom(
versions_pb2.VersionDef(
@@ -175,7 +178,7 @@ def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=Tr
# Update inputs of all nodes in graph.
for node_def in graph_def.node:
for i in range(len(node_def.input)):
node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]
node_def.input[i] = nested_to_flat_tensor_name[node_def.input.get(i)]

# Create _Retval for output nodes.
create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name)
@@ -184,6 +187,7 @@ def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=Tr


def convert_graphs(filename):
"""Convert graphs."""
try:
with tf.io.gfile.GFile(filename, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
@@ -204,9 +208,10 @@ def convert_graphs(filename):


def convert_subgraphs(graph_def, filename):
"""Convert sub graphs."""
graph_def_library = graph_library_pb2.GraphDefLibrary()
for i, fdef in enumerate(graph_def.library.function):
sub_graph, nested_to_flat_tensor_name = convert_function_def_to_graph_def(fdef, copy_functions=False)
sub_graph, _ = convert_function_def_to_graph_def(fdef, copy_functions=False)
print("INFO: Convert FunctionDef, index:{}, name:{}".format(str(i), fdef.signature.name))
sub_graph_name = '{}.pb'.format(fdef.signature.name)
result_path = '{}/results'.format(os.path.dirname(os.path.abspath(filename)))
@@ -229,6 +234,7 @@ def convert_subgraphs(graph_def, filename):


def usage():
"""Print the usage."""
print(
'''
Based on tensorflow 1.15 or later, Python 3
@@ -265,5 +271,5 @@ if __name__ == '__main__':
break
except getopt.GetoptError:
print("ERROR: Input parameters is invalid, use '--help' to view the help.")
if (len(sys.argv) == 1):
if len(sys.argv) == 1:
print("INFO: Please specify the input parameters, and use '--help' to view the help.")

Loading…
Cancel
Save