| @@ -5,13 +5,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. | |||
| #------------------------------------------------------------------- | |||
| """Class to convert function to graph""" | |||
| import os | |||
| import sys | |||
| import getopt | |||
| from google.protobuf import text_format | |||
| import graph_library_pb2 | |||
| import tensorflow as tf | |||
| from tensorflow.python.framework import function_def_to_graph | |||
| from tensorflow.python.framework.errors_impl import NotFoundError | |||
| from tensorflow.python.platform import gfile | |||
| @@ -20,34 +21,30 @@ from tensorflow.core.framework import tensor_shape_pb2 | |||
| from tensorflow.core.framework import types_pb2 | |||
| from tensorflow.core.framework import versions_pb2 | |||
| from tensorflow.python.eager import context | |||
| from tensorflow.python.framework import importer | |||
| from tensorflow.python.framework import ops | |||
| from tensorflow.python.framework import versions | |||
| sys.path.append(os.path.join(os.path.split(os.path.realpath(__file__))[0], "util")) | |||
| import graph_library_pb2 | |||
| def _get_num_args(arg_def, node_def): | |||
| if arg_def.number_attr: | |||
| return node_def.attr[arg_def.number_attr].i | |||
| elif arg_def.type_list_attr: | |||
| if arg_def.type_list_attr: | |||
| return len(node_def.attr[arg_def.type_list_attr].list.type) | |||
| elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: | |||
| if arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: | |||
| return 1 | |||
| else: | |||
| raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) | |||
| raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) | |||
| def is_function(fname): | |||
| """Checks for a function definition with `fname` in the current context.""" | |||
| if context.executing_eagerly(): | |||
| return context.context().has_function(fname) | |||
| else: | |||
| return ops.get_default_graph()._is_function(fname) | |||
| return ops.get_default_graph()._is_function(fname) | |||
| def create_arg_for_input_nodes(fdef, graph_def, input_shapes): | |||
| """Create arg for input nodes.""" | |||
| for i, arg_def in enumerate(fdef.signature.input_arg): | |||
| node_def = graph_def.node.add() | |||
| node_def.name = arg_def.name | |||
| @@ -65,9 +62,10 @@ def create_arg_for_input_nodes(fdef, graph_def, input_shapes): | |||
| # applied to these Arg nodes. | |||
| if k.startswith("_"): | |||
| node_def.attr[k].CopyFrom(arg_attrs[k]) | |||
| return | |||
| def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name): | |||
| """Create retval for output nodes.""" | |||
| for i, arg_def in enumerate(fdef.signature.output_arg): | |||
| node_def = graph_def.node.add() | |||
| node_def.name = '{}_Retval'.format(arg_def.name) | |||
| @@ -78,9 +76,10 @@ def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name): | |||
| ret_name = fdef.ret[arg_def.name] | |||
| node_def.input.append(nested_to_flat_tensor_name[ret_name]) | |||
| return | |||
| def updat_input_index(node_def, op_def, nested_to_flat_tensor_name): | |||
| """Update input index.""" | |||
| flattened_index = 0 | |||
| for arg_def in op_def.output_arg: | |||
| num_args = _get_num_args(arg_def, node_def) | |||
| @@ -98,14 +97,16 @@ def updat_input_index(node_def, op_def, nested_to_flat_tensor_name): | |||
| nested_to_flat_tensor_name[control_name] = control_name | |||
| return | |||
| def build_tensor_name(fdef, default_graph): | |||
| """Build tensor name.""" | |||
| nested_to_flat_tensor_name = {} | |||
| for arg_def in fdef.signature.input_arg: | |||
| nested_to_flat_tensor_name[arg_def.name] = arg_def.name | |||
| control_name = '^{}'.format(arg_def.name) | |||
| nested_to_flat_tensor_name[control_name] = control_name | |||
| global op_def | |||
| op_def = None | |||
| for node_def in fdef.node_def: | |||
| f = default_graph._functions.get(node_def.op, None) | |||
| if f is not None and hasattr(f, "signature"): | |||
| @@ -136,7 +137,9 @@ def build_tensor_name(fdef, default_graph): | |||
| updat_input_index(node_def, op_def, nested_to_flat_tensor_name) | |||
| return nested_to_flat_tensor_name | |||
| def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True): | |||
| """Convert function def to graph def""" | |||
| graph_def = graph_pb2.GraphDef() | |||
| graph_def.versions.CopyFrom( | |||
| versions_pb2.VersionDef( | |||
| @@ -175,7 +178,7 @@ def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=Tr | |||
| # Update inputs of all nodes in graph. | |||
| for node_def in graph_def.node: | |||
| for i in range(len(node_def.input)): | |||
| node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]] | |||
| node_def.input[i] = nested_to_flat_tensor_name[node_def.input.get(i)] | |||
| # Create _Retval for output nodes. | |||
| create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name) | |||
| @@ -184,6 +187,7 @@ def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=Tr | |||
| def convert_graphs(filename): | |||
| """Convert graphs.""" | |||
| try: | |||
| with tf.io.gfile.GFile(filename, 'rb') as f: | |||
| graph_def = tf.compat.v1.GraphDef() | |||
| @@ -204,9 +208,10 @@ def convert_graphs(filename): | |||
| def convert_subgraphs(graph_def, filename): | |||
| """Convert sub graphs.""" | |||
| graph_def_library = graph_library_pb2.GraphDefLibrary() | |||
| for i, fdef in enumerate(graph_def.library.function): | |||
| sub_graph, nested_to_flat_tensor_name = convert_function_def_to_graph_def(fdef, copy_functions=False) | |||
| sub_graph, _ = convert_function_def_to_graph_def(fdef, copy_functions=False) | |||
| print("INFO: Convert FunctionDef, index:{}, name:{}".format(str(i), fdef.signature.name)) | |||
| sub_graph_name = '{}.pb'.format(fdef.signature.name) | |||
| result_path = '{}/results'.format(os.path.dirname(os.path.abspath(filename))) | |||
| @@ -229,6 +234,7 @@ def convert_subgraphs(graph_def, filename): | |||
| def usage(): | |||
| """Print the usage.""" | |||
| print( | |||
| ''' | |||
| Based on tensorflow 1.15 or later, Python 3 | |||
| @@ -265,5 +271,5 @@ if __name__ == '__main__': | |||
| break | |||
| except getopt.GetoptError: | |||
| print("ERROR: Input parameters is invalid, use '--help' to view the help.") | |||
| if (len(sys.argv) == 1): | |||
| if len(sys.argv) == 1: | |||
| print("INFO: Please specify the input parameters, and use '--help' to view the help.") | |||