|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- #-------------------------------------------------------------------
- # Purpose:
- # Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved.
- #-------------------------------------------------------------------
-
"""Convert the FunctionDefs in a TensorFlow model into standalone GraphDefs."""
-
- import os
- import sys
- import getopt
-
- from google.protobuf import text_format
- import tensorflow as tf
- from tensorflow.python.framework.errors_impl import NotFoundError
- from tensorflow.python.platform import gfile
-
- from tensorflow.core.framework import graph_pb2
- from tensorflow.core.framework import tensor_shape_pb2
- from tensorflow.core.framework import types_pb2
- from tensorflow.core.framework import versions_pb2
- from tensorflow.python.eager import context
- from tensorflow.python.framework import ops
- from tensorflow.python.framework import versions
-
- sys.path.append(os.path.join(os.path.split(os.path.realpath(__file__))[0], "util"))
-
- import graph_library_pb2
-
-
- def _get_num_args(arg_def, node_def):
- if arg_def.number_attr:
- return node_def.attr[arg_def.number_attr].i
- if arg_def.type_list_attr:
- return len(node_def.attr[arg_def.type_list_attr].list.type)
- if arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
- return 1
- raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
-
-
def is_function(fname):
    """Tell whether a function named `fname` is defined in the active context.

    In eager mode the eager context's registry is asked; otherwise the
    default graph is asked.
    """
    if not context.executing_eagerly():
        return ops.get_default_graph()._is_function(fname)
    return context.context().has_function(fname)
-
-
def create_arg_for_input_nodes(fdef, graph_def, input_shapes):
    """Append one ``_Arg`` node to `graph_def` per input argument of `fdef`.

    Each node gets the input's name, its dtype as attr ``T`` and its position
    as attr ``index``. If `input_shapes` has a non-None entry for the input,
    that shape is stored as attr ``shape``. Underscore-prefixed (internal)
    attributes recorded for the input in ``fdef.arg_attr`` are copied over.

    Args:
        fdef: FunctionDef whose ``signature.input_arg`` is converted.
        graph_def: GraphDef receiving the new nodes (modified in place).
        input_shapes: Optional sequence indexed like the inputs; each entry is
            a TensorShapeProto, an object with ``as_proto()``, or None.
    """
    for i, arg_def in enumerate(fdef.signature.input_arg):
        node_def = graph_def.node.add()
        node_def.name = arg_def.name
        node_def.op = "_Arg"
        node_def.attr["T"].type = arg_def.type
        node_def.attr["index"].i = i
        if input_shapes and input_shapes[i] is not None:
            input_shape = input_shapes[i]
            if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto):
                input_shape = input_shape.as_proto()
            node_def.attr["shape"].shape.CopyFrom(input_shape)
        # Test membership first: indexing a protobuf message map with a missing
        # key inserts an empty entry, which would silently modify `fdef`.
        if i in fdef.arg_attr:
            for key, value in fdef.arg_attr[i].attr.items():
                # Only copy internal attributes. Normal attributes for nodes
                # cannot be applied to these Arg nodes.
                if key.startswith("_"):
                    node_def.attr[key].CopyFrom(value)
-
-
def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name):
    """Append one ``_Retval`` node to `graph_def` per output argument of `fdef`.

    Each node is named ``<output_arg>_Retval``. It carries the output's dtype
    (attr ``T``), its position (attr ``index``) and the serialized ``_Retval``
    OpDef (string attr ``op_def``). Its single input is the flat name of the
    tensor that ``fdef.ret`` binds to that output.

    Args:
        fdef: FunctionDef whose ``signature.output_arg`` is converted.
        graph_def: GraphDef receiving the new nodes (modified in place).
        nested_to_flat_tensor_name: Mapping from FunctionDef tensor names to
            GraphDef tensor names.

    Raises:
        ValueError: if an output's return tensor is missing from
            `nested_to_flat_tensor_name`.
    """
    outputs = fdef.signature.output_arg
    if not outputs:
        return
    # The _Retval OpDef is the same for every output: look it up and
    # serialize it once instead of once per output.
    retval_op_def = ops.get_default_graph()._get_op_def("_Retval").SerializeToString()

    for i, arg_def in enumerate(outputs):
        ret_name = fdef.ret[arg_def.name]
        if ret_name not in nested_to_flat_tensor_name:
            raise ValueError("Output '{}' refers to unknown tensor '{}'.".format(
                arg_def.name, ret_name))
        node_def = graph_def.node.add()
        node_def.name = '{}_Retval'.format(arg_def.name)
        node_def.op = "_Retval"
        node_def.attr["T"].type = arg_def.type
        node_def.attr["index"].i = i
        node_def.attr["op_def"].s = retval_op_def
        node_def.input.append(nested_to_flat_tensor_name[ret_name])
-
-
def updat_input_index(node_def, op_def, nested_to_flat_tensor_name):
    """Record GraphDef-style names for every output tensor of `node_def`.

    Despite its name, this does not touch node inputs. It fills
    `nested_to_flat_tensor_name` in place so that each FunctionDef-style
    output name ``"<node>:<output_arg>:<i>"`` maps to ``"<node>:<k>"``.
    Here ``k`` counts outputs across all of the op's output args, and
    position 0 is written as plain ``"<node>"``. The control name
    ``"^<node>"`` is mapped to itself.
    """
    position = 0
    for out_arg in op_def.output_arg:
        count = _get_num_args(out_arg, node_def)
        for k in range(count):
            nested = "{}:{}:{}".format(node_def.name, out_arg.name, k)
            nested_to_flat_tensor_name[nested] = (
                "{}:{}".format(node_def.name, position) if position else node_def.name)
            position += 1
    control = "^" + node_def.name
    nested_to_flat_tensor_name[control] = control
-
-
def _check_func_attrs(node_def, op_def):
    """Raise ValueError if a func / list(func) attr of `node_def` names an unknown function."""
    for attr in op_def.attr:
        if attr.type == "func":
            fnames = [node_def.attr[attr.name].func.name]
        elif attr.type == "list(func)":
            fnames = [fn.name for fn in node_def.attr[attr.name].list.func]
        else:
            continue
        for fname in fnames:
            if not is_function(fname):
                raise ValueError("%s function not found." % fname)


def build_tensor_name(fdef, default_graph, graph_def=None, copied_functions=None):
    """Map the nested tensor names used inside `fdef` to flat GraphDef names.

    FunctionDef bodies name tensors ``node:output_arg:index``, while GraphDefs
    use ``node:flat_index`` (plain ``node`` for index 0). Function inputs and
    their control names (``^name``) map to themselves.

    Args:
        fdef: FunctionDef being converted.
        default_graph: Graph whose ``_functions`` registry resolves nodes whose
            op type is itself a function.
        graph_def: Optional GraphDef under construction. When given, each
            function used as an op type and not yet in `copied_functions` is
            copied into its ``library``.
        copied_functions: Optional set of function names already present in
            `graph_def`; updated in place as functions are copied.

    Returns:
        dict mapping nested tensor names to flat tensor names.

    Raises:
        ValueError: if a func / list(func) attribute names an unknown function.
    """
    if copied_functions is None:
        copied_functions = set()

    nested_to_flat_tensor_name = {}
    for arg_def in fdef.signature.input_arg:
        nested_to_flat_tensor_name[arg_def.name] = arg_def.name
        control_name = '^{}'.format(arg_def.name)
        nested_to_flat_tensor_name[control_name] = control_name

    for node_def in fdef.node_def:
        f = default_graph._functions.get(node_def.op, None)
        if f is not None and hasattr(f, "signature"):
            op_def = f.signature
            if graph_def is not None and node_def.op not in copied_functions:
                # Since this function is referenced as an op type, we have no
                # choice but to copy it into the GraphDef if we want downstream
                # tools to process it.
                graph_def.library.function.add().CopyFrom(f.definition)
                copied_functions.add(node_def.op)
        else:
            op_def = ops.get_default_graph()._get_op_def(node_def.op)

        _check_func_attrs(node_def, op_def)
        # Map every output tensor of this node (flattened across all of the
        # op's output args) to its GraphDef name.
        updat_input_index(node_def, op_def, nested_to_flat_tensor_name)
    return nested_to_flat_tensor_name


def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
    """Convert a FunctionDef into an equivalent standalone GraphDef.

    Inputs become ``_Arg`` nodes and outputs become ``_Retval`` nodes. The body
    nodes are copied with their inputs renamed from FunctionDef (nested) to
    GraphDef (flat) tensor names.

    Args:
        fdef: FunctionDef to convert.
        input_shapes: Optional per-input shapes; if given, its length must
            equal the number of inputs.
        copy_functions: If True, copy every function of the default graph into
            the result's library. Functions used directly as op types are
            copied regardless.

    Returns:
        (graph_def, nested_to_flat_tensor_name).

    Raises:
        ValueError: if `input_shapes` has the wrong length, a func attribute
            names an unknown function, or a node input cannot be resolved.
    """
    graph_def = graph_pb2.GraphDef()
    graph_def.versions.CopyFrom(
        versions_pb2.VersionDef(
            producer=versions.GRAPH_DEF_VERSION,
            min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

    default_graph = ops.get_default_graph()

    # Copy *all* functions from outer graph to `graph_def` so that both direct
    # and indirect references are safely handled.
    copied_functions = set()
    if copy_functions:
        default_graph._copy_functions_to_graph_def(graph_def, 0)
        copied_functions.update(default_graph._functions.keys())

    if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
        raise ValueError(
            "Length of input_shapes must match the number of "
            "input_args. len(input_shapes): {} len(input_arg): {}".format(
                len(input_shapes), len(fdef.signature.input_arg)))

    # 1. Create _Arg for input nodes.
    create_arg_for_input_nodes(fdef, graph_def, input_shapes)

    # 2. Copy all body NodeDefs to the GraphDef.
    graph_def.node.extend(fdef.node_def)

    # 3. Build the nested -> flat tensor name mapping (see the comment on
    # `FunctionDef.node_def` for how FunctionDef tensor naming differs), then
    # rewrite every node input with it.
    nested_to_flat_tensor_name = build_tensor_name(
        fdef, default_graph, graph_def=graph_def, copied_functions=copied_functions)

    for node_def in graph_def.node:
        for i, name in enumerate(list(node_def.input)):
            flat_name = nested_to_flat_tensor_name.get(name)
            if flat_name is None:
                raise ValueError("Cannot resolve input '{}' of node '{}'.".format(
                    name, node_def.name))
            node_def.input[i] = flat_name

    # 4. Create _Retval for output nodes.
    create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name)

    return graph_def, nested_to_flat_tensor_name
-
-
def convert_graphs(filename):
    """Load the binary GraphDef in `filename` and convert its FunctionDefs.

    The model is imported into the default graph first, so that its function
    library is registered there before conversion. A missing or unparsable
    model file, and any failure during conversion, is reported on stdout
    instead of being raised.

    Args:
        filename: Path of the input model (a serialized GraphDef).
    """
    # google.protobuf is already a dependency of this module (text_format).
    from google.protobuf.message import DecodeError

    try:
        # Only the read needs the file; close it before the conversion runs.
        with tf.io.gfile.GFile(filename, 'rb') as f:
            serialized = f.read()
    except NotFoundError:
        print('ERROR: model file {} does not exist'.format(filename))
        return

    graph_def = tf.compat.v1.GraphDef()
    try:
        graph_def.ParseFromString(serialized)
    except DecodeError:
        print('ERROR: model file {} is not a valid GraphDef'.format(filename))
        return
    tf.import_graph_def(graph_def, name='')

    if not graph_def.library.function:
        print("INFO: The input model does not contain a functionDef and does not require conversion.")
        return
    try:
        convert_subgraphs(graph_def, filename)
    except Exception as e:  # CLI boundary: report any conversion failure.
        print("ERROR: Convert subgraphs failed.", e)
        return
    print("INFO: Convert to subgraphs successfully.")
-
-
def convert_subgraphs(graph_def, filename):
    """Convert every FunctionDef in `graph_def` into a standalone GraphDef.

    Each converted graph is written as ``<function name>.pb`` into a
    ``results`` directory next to `filename`. All of them are also collected
    into a GraphDefLibrary, written as ``graph_def_library.pbtxt`` next to
    `filename`.

    Args:
        graph_def: GraphDef whose ``library.function`` entries are converted.
        filename: Path of the input model; only its directory is used.

    Raises:
        OSError: if ``graph_def_library.pbtxt`` cannot be written, so that
            the caller reports the failure instead of claiming success.
    """
    output_dir = os.path.dirname(os.path.abspath(filename))
    result_path = os.path.join(output_dir, 'results')
    graph_def_library = graph_library_pb2.GraphDefLibrary()

    for i, fdef in enumerate(graph_def.library.function):
        func_name = fdef.signature.name
        sub_graph, _ = convert_function_def_to_graph_def(fdef, copy_functions=False)
        print("INFO: Convert FunctionDef, index:{}, name:{}".format(i, func_name))
        tf.io.write_graph(sub_graph, result_path, '{}.pb'.format(func_name), as_text=False)

        ge_graph_def = graph_library_pb2.GeGraphDef()
        ge_graph_def.name = func_name
        # Copied via bytes rather than CopyFrom: presumably GeGraphDef.graph is
        # generated from its own copy of graph.proto and so is a different
        # Python class than `sub_graph` (TODO confirm).
        ge_graph_def.graph.ParseFromString(sub_graph.SerializeToString())
        graph_def_library.graph_def.append(ge_graph_def)
        print(graph_def_library.graph_def[i])

    # Write to prototxt
    graph_def_file = os.path.join(output_dir, 'graph_def_library.pbtxt')
    print("graph_def_file: ", graph_def_file)
    with open(graph_def_file, "w", encoding="utf-8") as f:
        print(graph_def_library, file=f)
-
-
def usage():
    """Print the command-line help for this tool."""
    print(
        '''
Based on tensorflow 1.15 or later, Python 3

Convert the tensorflow functionDefs in the input model file to single GraphDefs,
and save the result to the "results" directory and graph_def_library.pbtxt in
the input file directory.
The name of each sub graph is the same as the name of the corresponding functionDef.

Usage: func2graph.py [options]

Options:
  -m, --model <file>   Input model file.
  -v, --version        Print the version of this software.
  -h, --help           Print this help message.
'''
    )
-
-
def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list without the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit status: 2 when the options cannot be parsed, else 0.
    """
    if argv is None:
        argv = sys.argv[1:]
    try:
        opts, _ = getopt.getopt(argv, 'vhm:', ['version', 'help', 'model='])
    except getopt.GetoptError:
        print("ERROR: Input parameters is invalid, use '--help' to view the help.")
        # Previously execution continued and crashed on the undefined `opts`.
        return 2
    if not opts:
        print("INFO: Please specify the input parameters, and use '--help' to view the help.")
        return 0
    for opt_name, opt_value in opts:
        if opt_name in ('-m', '--model'):
            print("INFO: Input model file is", opt_value)
            convert_graphs(opt_value)
        elif opt_name in ('-h', '--help'):
            usage()
            break
        elif opt_name in ('-v', '--version'):
            print("version 1.0.0")
            break
    return 0


if __name__ == '__main__':
    sys.exit(main())
|