diff --git a/parser/func_to_graph/func2graph.py b/parser/func_to_graph/func2graph.py index 633440f..4c7c085 100644 --- a/parser/func_to_graph/func2graph.py +++ b/parser/func_to_graph/func2graph.py @@ -5,13 +5,14 @@ # Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. #------------------------------------------------------------------- +"""Class to convert function to graph""" + import os import sys import getopt -from google.protobuf import text_format +import graph_library_pb2 import tensorflow as tf -from tensorflow.python.framework import function_def_to_graph from tensorflow.python.framework.errors_impl import NotFoundError from tensorflow.python.platform import gfile @@ -20,34 +21,30 @@ from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import versions_pb2 from tensorflow.python.eager import context -from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import versions sys.path.append(os.path.join(os.path.split(os.path.realpath(__file__))[0], "util")) -import graph_library_pb2 - - def _get_num_args(arg_def, node_def): if arg_def.number_attr: return node_def.attr[arg_def.number_attr].i - elif arg_def.type_list_attr: + if arg_def.type_list_attr: return len(node_def.attr[arg_def.type_list_attr].list.type) - elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: + if arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: return 1 - else: - raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) + raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) def is_function(fname): """Checks for a function definition with `fname` in the current context.""" if context.executing_eagerly(): return context.context().has_function(fname) - else: - return ops.get_default_graph()._is_function(fname) + return ops.get_default_graph()._is_function(fname) + def create_arg_for_input_nodes(fdef, graph_def, input_shapes): + """Create arg for input nodes.""" for i, arg_def in enumerate(fdef.signature.input_arg): node_def = graph_def.node.add() node_def.name = arg_def.name @@ -65,9 +62,10 @@ def create_arg_for_input_nodes(fdef, graph_def, input_shapes): # applied to these Arg nodes. if k.startswith("_"): node_def.attr[k].CopyFrom(arg_attrs[k]) - return + def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name): + """Create retval for output nodes.""" for i, arg_def in enumerate(fdef.signature.output_arg): node_def = graph_def.node.add() node_def.name = '{}_Retval'.format(arg_def.name) @@ -78,9 +76,10 @@ def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name): ret_name = fdef.ret[arg_def.name] node_def.input.append(nested_to_flat_tensor_name[ret_name]) - return + def updat_input_index(node_def, op_def, nested_to_flat_tensor_name): + """Update input index.""" flattened_index = 0 for arg_def in op_def.output_arg: num_args = _get_num_args(arg_def, node_def) @@ -98,14 +97,16 @@ def updat_input_index(node_def, op_def, nested_to_flat_tensor_name): nested_to_flat_tensor_name[control_name] = control_name return + def build_tensor_name(fdef, default_graph): + """Build tensor name.""" nested_to_flat_tensor_name = {} for arg_def in fdef.signature.input_arg: nested_to_flat_tensor_name[arg_def.name] = arg_def.name control_name = '^{}'.format(arg_def.name) nested_to_flat_tensor_name[control_name] = control_name - global op_def + op_def = None for node_def in fdef.node_def: f = default_graph._functions.get(node_def.op, None) if f is not None and hasattr(f, "signature"): @@ -136,7 +137,9 @@ def build_tensor_name(fdef, default_graph): updat_input_index(node_def, op_def, nested_to_flat_tensor_name) return nested_to_flat_tensor_name + def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True): + """Convert function def to graph def""" graph_def = graph_pb2.GraphDef() graph_def.versions.CopyFrom( versions_pb2.VersionDef( @@ -175,7 +178,7 @@ def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=Tr # Update inputs of all nodes in graph. for node_def in graph_def.node: for i in range(len(node_def.input)): - node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]] + node_def.input[i] = nested_to_flat_tensor_name[node_def.input.get(i)] # Create _Retval for output nodes. create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name) @@ -184,6 +187,7 @@ def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=Tr def convert_graphs(filename): + """Convert graphs.""" try: with tf.io.gfile.GFile(filename, 'rb') as f: graph_def = tf.compat.v1.GraphDef() @@ -204,9 +208,10 @@ def convert_graphs(filename): def convert_subgraphs(graph_def, filename): + """Convert sub graphs.""" graph_def_library = graph_library_pb2.GraphDefLibrary() for i, fdef in enumerate(graph_def.library.function): - sub_graph, nested_to_flat_tensor_name = convert_function_def_to_graph_def(fdef, copy_functions=False) + sub_graph, _ = convert_function_def_to_graph_def(fdef, copy_functions=False) print("INFO: Convert FunctionDef, index:{}, name:{}".format(str(i), fdef.signature.name)) sub_graph_name = '{}.pb'.format(fdef.signature.name) result_path = '{}/results'.format(os.path.dirname(os.path.abspath(filename))) @@ -229,6 +234,7 @@ def convert_subgraphs(graph_def, filename): def usage(): + """Print the usage.""" print( ''' Based on tensorflow 1.15 or later, Python 3 @@ -265,5 +271,5 @@ if __name__ == '__main__': break except getopt.GetoptError: print("ERROR: Input parameters is invalid, use '--help' to view the help.") - if (len(sys.argv) == 1): + if len(sys.argv) == 1: print("INFO: Please specify the input parameters, and use '--help' to view the help.")