| @@ -5,13 +5,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. | # Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. | ||||
| #------------------------------------------------------------------- | #------------------------------------------------------------------- | ||||
| """Class to convert function to graph""" | |||||
| import os | import os | ||||
| import sys | import sys | ||||
| import getopt | import getopt | ||||
| from google.protobuf import text_format | |||||
| import graph_library_pb2 | |||||
| import tensorflow as tf | import tensorflow as tf | ||||
| from tensorflow.python.framework import function_def_to_graph | |||||
| from tensorflow.python.framework.errors_impl import NotFoundError | from tensorflow.python.framework.errors_impl import NotFoundError | ||||
| from tensorflow.python.platform import gfile | from tensorflow.python.platform import gfile | ||||
| @@ -20,34 +21,30 @@ from tensorflow.core.framework import tensor_shape_pb2 | |||||
| from tensorflow.core.framework import types_pb2 | from tensorflow.core.framework import types_pb2 | ||||
| from tensorflow.core.framework import versions_pb2 | from tensorflow.core.framework import versions_pb2 | ||||
| from tensorflow.python.eager import context | from tensorflow.python.eager import context | ||||
| from tensorflow.python.framework import importer | |||||
| from tensorflow.python.framework import ops | from tensorflow.python.framework import ops | ||||
| from tensorflow.python.framework import versions | from tensorflow.python.framework import versions | ||||
| sys.path.append(os.path.join(os.path.split(os.path.realpath(__file__))[0], "util")) | sys.path.append(os.path.join(os.path.split(os.path.realpath(__file__))[0], "util")) | ||||
| import graph_library_pb2 | |||||
| def _get_num_args(arg_def, node_def): | def _get_num_args(arg_def, node_def): | ||||
| if arg_def.number_attr: | if arg_def.number_attr: | ||||
| return node_def.attr[arg_def.number_attr].i | return node_def.attr[arg_def.number_attr].i | ||||
| elif arg_def.type_list_attr: | |||||
| if arg_def.type_list_attr: | |||||
| return len(node_def.attr[arg_def.type_list_attr].list.type) | return len(node_def.attr[arg_def.type_list_attr].list.type) | ||||
| elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: | |||||
| if arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: | |||||
| return 1 | return 1 | ||||
| else: | |||||
| raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) | |||||
| raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) | |||||
| def is_function(fname): | def is_function(fname): | ||||
| """Checks for a function definition with `fname` in the current context.""" | """Checks for a function definition with `fname` in the current context.""" | ||||
| if context.executing_eagerly(): | if context.executing_eagerly(): | ||||
| return context.context().has_function(fname) | return context.context().has_function(fname) | ||||
| else: | |||||
| return ops.get_default_graph()._is_function(fname) | |||||
| return ops.get_default_graph()._is_function(fname) | |||||
| def create_arg_for_input_nodes(fdef, graph_def, input_shapes): | def create_arg_for_input_nodes(fdef, graph_def, input_shapes): | ||||
| """Create arg for input nodes.""" | |||||
| for i, arg_def in enumerate(fdef.signature.input_arg): | for i, arg_def in enumerate(fdef.signature.input_arg): | ||||
| node_def = graph_def.node.add() | node_def = graph_def.node.add() | ||||
| node_def.name = arg_def.name | node_def.name = arg_def.name | ||||
| @@ -65,9 +62,10 @@ def create_arg_for_input_nodes(fdef, graph_def, input_shapes): | |||||
| # applied to these Arg nodes. | # applied to these Arg nodes. | ||||
| if k.startswith("_"): | if k.startswith("_"): | ||||
| node_def.attr[k].CopyFrom(arg_attrs[k]) | node_def.attr[k].CopyFrom(arg_attrs[k]) | ||||
| return | |||||
| def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name): | def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name): | ||||
| """Create retval for output nodes.""" | |||||
| for i, arg_def in enumerate(fdef.signature.output_arg): | for i, arg_def in enumerate(fdef.signature.output_arg): | ||||
| node_def = graph_def.node.add() | node_def = graph_def.node.add() | ||||
| node_def.name = '{}_Retval'.format(arg_def.name) | node_def.name = '{}_Retval'.format(arg_def.name) | ||||
| @@ -78,9 +76,10 @@ def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name): | |||||
| ret_name = fdef.ret[arg_def.name] | ret_name = fdef.ret[arg_def.name] | ||||
| node_def.input.append(nested_to_flat_tensor_name[ret_name]) | node_def.input.append(nested_to_flat_tensor_name[ret_name]) | ||||
| return | |||||
| def updat_input_index(node_def, op_def, nested_to_flat_tensor_name): | def updat_input_index(node_def, op_def, nested_to_flat_tensor_name): | ||||
| """Update input index.""" | |||||
| flattened_index = 0 | flattened_index = 0 | ||||
| for arg_def in op_def.output_arg: | for arg_def in op_def.output_arg: | ||||
| num_args = _get_num_args(arg_def, node_def) | num_args = _get_num_args(arg_def, node_def) | ||||
| @@ -98,14 +97,16 @@ def updat_input_index(node_def, op_def, nested_to_flat_tensor_name): | |||||
| nested_to_flat_tensor_name[control_name] = control_name | nested_to_flat_tensor_name[control_name] = control_name | ||||
| return | return | ||||
| def build_tensor_name(fdef, default_graph): | def build_tensor_name(fdef, default_graph): | ||||
| """Build tensor name.""" | |||||
| nested_to_flat_tensor_name = {} | nested_to_flat_tensor_name = {} | ||||
| for arg_def in fdef.signature.input_arg: | for arg_def in fdef.signature.input_arg: | ||||
| nested_to_flat_tensor_name[arg_def.name] = arg_def.name | nested_to_flat_tensor_name[arg_def.name] = arg_def.name | ||||
| control_name = '^{}'.format(arg_def.name) | control_name = '^{}'.format(arg_def.name) | ||||
| nested_to_flat_tensor_name[control_name] = control_name | nested_to_flat_tensor_name[control_name] = control_name | ||||
| global op_def | |||||
| op_def = None | |||||
| for node_def in fdef.node_def: | for node_def in fdef.node_def: | ||||
| f = default_graph._functions.get(node_def.op, None) | f = default_graph._functions.get(node_def.op, None) | ||||
| if f is not None and hasattr(f, "signature"): | if f is not None and hasattr(f, "signature"): | ||||
| @@ -136,7 +137,9 @@ def build_tensor_name(fdef, default_graph): | |||||
| updat_input_index(node_def, op_def, nested_to_flat_tensor_name) | updat_input_index(node_def, op_def, nested_to_flat_tensor_name) | ||||
| return nested_to_flat_tensor_name | return nested_to_flat_tensor_name | ||||
| def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True): | def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True): | ||||
| """Convert function def to graph def""" | |||||
| graph_def = graph_pb2.GraphDef() | graph_def = graph_pb2.GraphDef() | ||||
| graph_def.versions.CopyFrom( | graph_def.versions.CopyFrom( | ||||
| versions_pb2.VersionDef( | versions_pb2.VersionDef( | ||||
| @@ -175,7 +178,7 @@ def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=Tr | |||||
| # Update inputs of all nodes in graph. | # Update inputs of all nodes in graph. | ||||
| for node_def in graph_def.node: | for node_def in graph_def.node: | ||||
| for i in range(len(node_def.input)): | for i in range(len(node_def.input)): | ||||
| node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]] | |||||
| node_def.input[i] = nested_to_flat_tensor_name[node_def.input.get(i)] | |||||
| # Create _Retval for output nodes. | # Create _Retval for output nodes. | ||||
| create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name) | create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name) | ||||
| @@ -184,6 +187,7 @@ def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=Tr | |||||
| def convert_graphs(filename): | def convert_graphs(filename): | ||||
| """Convert graphs.""" | |||||
| try: | try: | ||||
| with tf.io.gfile.GFile(filename, 'rb') as f: | with tf.io.gfile.GFile(filename, 'rb') as f: | ||||
| graph_def = tf.compat.v1.GraphDef() | graph_def = tf.compat.v1.GraphDef() | ||||
| @@ -204,9 +208,10 @@ def convert_graphs(filename): | |||||
| def convert_subgraphs(graph_def, filename): | def convert_subgraphs(graph_def, filename): | ||||
| """Convert sub graphs.""" | |||||
| graph_def_library = graph_library_pb2.GraphDefLibrary() | graph_def_library = graph_library_pb2.GraphDefLibrary() | ||||
| for i, fdef in enumerate(graph_def.library.function): | for i, fdef in enumerate(graph_def.library.function): | ||||
| sub_graph, nested_to_flat_tensor_name = convert_function_def_to_graph_def(fdef, copy_functions=False) | |||||
| sub_graph, _ = convert_function_def_to_graph_def(fdef, copy_functions=False) | |||||
| print("INFO: Convert FunctionDef, index:{}, name:{}".format(str(i), fdef.signature.name)) | print("INFO: Convert FunctionDef, index:{}, name:{}".format(str(i), fdef.signature.name)) | ||||
| sub_graph_name = '{}.pb'.format(fdef.signature.name) | sub_graph_name = '{}.pb'.format(fdef.signature.name) | ||||
| result_path = '{}/results'.format(os.path.dirname(os.path.abspath(filename))) | result_path = '{}/results'.format(os.path.dirname(os.path.abspath(filename))) | ||||
| @@ -229,6 +234,7 @@ def convert_subgraphs(graph_def, filename): | |||||
| def usage(): | def usage(): | ||||
| """Print the usage.""" | |||||
| print( | print( | ||||
| ''' | ''' | ||||
| Based on tensorflow 1.15 or later, Python 3 | Based on tensorflow 1.15 or later, Python 3 | ||||
| @@ -265,5 +271,5 @@ if __name__ == '__main__': | |||||
| break | break | ||||
| except getopt.GetoptError: | except getopt.GetoptError: | ||||
| print("ERROR: Input parameters is invalid, use '--help' to view the help.") | print("ERROR: Input parameters is invalid, use '--help' to view the help.") | ||||
| if (len(sys.argv) == 1): | |||||
| if len(sys.argv) == 1: | |||||
| print("INFO: Please specify the input parameters, and use '--help' to view the help.") | print("INFO: Please specify the input parameters, and use '--help' to view the help.") | ||||