You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

func2graph.py 11 kB

4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
3 years ago
5 years ago
4 years ago
5 years ago
3 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. #-------------------------------------------------------------------
  4. # Purpose:
  5. # Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved.
  6. #-------------------------------------------------------------------
  7. """Class to convert function to graph"""
  8. import os
  9. import sys
  10. import getopt
  11. from google.protobuf import text_format
  12. import tensorflow as tf
  13. from tensorflow.python.framework.errors_impl import NotFoundError
  14. from tensorflow.python.platform import gfile
  15. from tensorflow.core.framework import graph_pb2
  16. from tensorflow.core.framework import tensor_shape_pb2
  17. from tensorflow.core.framework import types_pb2
  18. from tensorflow.core.framework import versions_pb2
  19. from tensorflow.python.eager import context
  20. from tensorflow.python.framework import ops
  21. from tensorflow.python.framework import versions
  22. sys.path.append(os.path.join(os.path.split(os.path.realpath(__file__))[0], "util"))
  23. import graph_library_pb2
  24. def _get_num_args(arg_def, node_def):
  25. if arg_def.number_attr:
  26. return node_def.attr[arg_def.number_attr].i
  27. if arg_def.type_list_attr:
  28. return len(node_def.attr[arg_def.type_list_attr].list.type)
  29. if arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
  30. return 1
  31. raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
  32. def is_function(fname):
  33. """Checks for a function definition with `fname` in the current context."""
  34. if context.executing_eagerly():
  35. return context.context().has_function(fname)
  36. return ops.get_default_graph()._is_function(fname)
  37. def create_arg_for_input_nodes(fdef, graph_def, input_shapes):
  38. """Create arg for input nodes."""
  39. for i, arg_def in enumerate(fdef.signature.input_arg):
  40. node_def = graph_def.node.add()
  41. node_def.name = arg_def.name
  42. node_def.op = "_Arg"
  43. node_def.attr["T"].type = arg_def.type
  44. node_def.attr["index"].i = i
  45. if input_shapes and input_shapes[i] is not None:
  46. input_shape = input_shapes[i]
  47. if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto):
  48. input_shape = input_shape.as_proto()
  49. node_def.attr["shape"].shape.CopyFrom(input_shape)
  50. arg_attrs = fdef.arg_attr[i].attr
  51. for k in arg_attrs:
  52. # Only copy internal attributes. Normal attributes for nodes cannot be
  53. # applied to these Arg nodes.
  54. if k.startswith("_"):
  55. node_def.attr[k].CopyFrom(arg_attrs[k])
  56. def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name):
  57. """Create retval for output nodes."""
  58. for i, arg_def in enumerate(fdef.signature.output_arg):
  59. node_def = graph_def.node.add()
  60. node_def.name = '{}_Retval'.format(arg_def.name)
  61. node_def.op = "_Retval"
  62. node_def.attr["T"].type = arg_def.type
  63. node_def.attr["index"].i = i
  64. node_def.attr["op_def"].s = ops.get_default_graph()._get_op_def(node_def.op).SerializeToString()
  65. ret_name = fdef.ret[arg_def.name]
  66. node_def.input.append(nested_to_flat_tensor_name[ret_name])
  67. def updat_input_index(node_def, op_def, nested_to_flat_tensor_name):
  68. """Update input index."""
  69. flattened_index = 0
  70. for arg_def in op_def.output_arg:
  71. num_args = _get_num_args(arg_def, node_def)
  72. for i in range(num_args):
  73. # Map tensor names from "node_name:output_arg_name:index" to
  74. # "node_name:flattened_index".
  75. nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
  76. if flattened_index == 0:
  77. flat_name = node_def.name
  78. else:
  79. flat_name = "{}:{}".format(node_def.name, flattened_index)
  80. nested_to_flat_tensor_name[nested_name] = flat_name
  81. flattened_index += 1
  82. control_name = "^" + node_def.name
  83. nested_to_flat_tensor_name[control_name] = control_name
  84. return
  85. def build_tensor_name(fdef, default_graph):
  86. """Build tensor name."""
  87. nested_to_flat_tensor_name = {}
  88. for arg_def in fdef.signature.input_arg:
  89. nested_to_flat_tensor_name[arg_def.name] = arg_def.name
  90. control_name = '^{}'.format(arg_def.name)
  91. nested_to_flat_tensor_name[control_name] = control_name
  92. op_def = None
  93. for node_def in fdef.node_def:
  94. f = default_graph._functions.get(node_def.op, None)
  95. if f is not None and hasattr(f, "signature"):
  96. op_def = f.signature
  97. if node_def.op not in copied_functions:
  98. # Since this function is referenced as an op type, we have no choice but
  99. # to copy it into the GraphDef if we want downstream tools to process
  100. # it.
  101. graph_def.library.function.add().CopyFrom(f.definition)
  102. copied_functions.add(node_def.op)
  103. else:
  104. op_def = ops.get_default_graph()._get_op_def(node_def.op)
  105. for attr in op_def.attr:
  106. if attr.type == "func":
  107. fname = node_def.attr[attr.name].func.name
  108. if not is_function(fname):
  109. raise ValueError("%s function not found." % fname)
  110. elif attr.type == "list(func)":
  111. for fn in node_def.attr[attr.name].list.func:
  112. fname = fn.name
  113. if not is_function(fname):
  114. raise ValueError("%s function not found." % fname)
  115. # Iterate over output_args in op_def to build the map.
  116. # Index of the output tensor in the flattened list of *all* output
  117. # tensors of the op.
  118. updat_input_index(node_def, op_def, nested_to_flat_tensor_name)
  119. return nested_to_flat_tensor_name
  120. def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
  121. """Convert function def to graph def"""
  122. graph_def = graph_pb2.GraphDef()
  123. graph_def.versions.CopyFrom(
  124. versions_pb2.VersionDef(
  125. producer=versions.GRAPH_DEF_VERSION,
  126. min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))
  127. default_graph = ops.get_default_graph()
  128. copied_functions = set()
  129. # Copy *all* functions from outer graph to `graph_def` so that both direct
  130. # and indirect references are safely handled.
  131. if copy_functions:
  132. default_graph._copy_functions_to_graph_def(graph_def, 0)
  133. for function_name in default_graph._functions.keys():
  134. copied_functions.add(function_name)
  135. if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
  136. raise ValueError("Length of input_shapes must match the number of " +
  137. "input_args. len(input_shapes): {} len(input_arg): {}".
  138. format(len(input_shapes), len(fdef.signature.input_arg)))
  139. # 1. Create _Arg for input nodes.
  140. create_arg_for_input_nodes(fdef, graph_def, input_shapes)
  141. # 2. Copy all body NodeDefs to the GraphDef.
  142. graph_def.node.extend(fdef.node_def)
  143. # 3. Perform the renaming.
  144. # Build the tensor name mapping then flatten the tensor names.
  145. # See comment on `FunctionDef.node_def` on how the tensor naming in
  146. # FunctionDefs is different from GraphDefs.
  147. nested_to_flat_tensor_name = build_tensor_name(fdef, default_graph)
  148. # Update inputs of all nodes in graph.
  149. for node_def in graph_def.node:
  150. for i in range(len(node_def.input)):
  151. node_def.input[i] = nested_to_flat_tensor_name.get(node_def.input[i])
  152. # Create _Retval for output nodes.
  153. create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name)
  154. return graph_def, nested_to_flat_tensor_name
  155. def convert_graphs(filename):
  156. """Convert graphs."""
  157. try:
  158. with tf.io.gfile.GFile(filename, 'rb') as f:
  159. graph_def = tf.compat.v1.GraphDef()
  160. graph_def.ParseFromString(f.read())
  161. tf.import_graph_def(graph_def, name='')
  162. if len(graph_def.library.function) == 0:
  163. print("INFO: The input model does not contain a functionDef and does not require conversion.")
  164. return
  165. try:
  166. convert_subgraphs(graph_def, filename)
  167. except Exception as e:
  168. print("ERROR: Convert subgraphs failed.", e)
  169. return
  170. print("INFO: Convert to subgraphs successfully.")
  171. except NotFoundError:
  172. print('ERROR: model file {} does not exist'.format(filename))
  173. return
  174. def convert_subgraphs(graph_def, filename):
  175. """Convert sub graphs."""
  176. graph_def_library = graph_library_pb2.GraphDefLibrary()
  177. for i, fdef in enumerate(graph_def.library.function):
  178. sub_graph, _ = convert_function_def_to_graph_def(fdef, copy_functions=False)
  179. print("INFO: Convert FunctionDef, index:{}, name:{}".format(str(i), fdef.signature.name))
  180. sub_graph_name = '{}.pb'.format(fdef.signature.name)
  181. result_path = '{}/results'.format(os.path.dirname(os.path.abspath(filename)))
  182. tf.io.write_graph(sub_graph, result_path, sub_graph_name, as_text=False)
  183. data = sub_graph.SerializeToString()
  184. ge_graph_def = graph_library_pb2.GeGraphDef()
  185. ge_graph_def.name = fdef.signature.name
  186. ge_graph_def.graph.ParseFromString(data)
  187. graph_def_library.graph_def.append(ge_graph_def)
  188. print(graph_def_library.graph_def[i])
  189. # Write to prototxt
  190. graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename)))
  191. print("graph_def_file: ", graph_def_file)
  192. try:
  193. with open(graph_def_file, "w") as f:
  194. print(graph_def_library, file=f)
  195. except IOError:
  196. print("Could not open file. Creating a new one.")
  197. def usage():
  198. """Print the usage."""
  199. print(
  200. '''
  201. Based on tensorflow 1.15 or later, Python 3
  202. Convert the tensorflow functionDefs in the input model file to single GraphDefs,
  203. and save the result to the "results" directory and graph_def_library.pbtxt in
  204. the input file directory.
  205. The name of the sub graph is same as the name of the corresponding functionDef.
  206. Usage: func2grpah.py <command>
  207. Available commands:
  208. model (-m) Input model file.
  209. version (-v) Prints the version of this software.
  210. help (-h) Prints help for commands.
  211. '''
  212. )
  213. if __name__ == '__main__':
  214. model = ''
  215. try:
  216. opts, args = getopt.getopt(sys.argv[1:], '-v-h-m:', ['version', 'help', 'model='])
  217. except getopt.GetoptError:
  218. print("ERROR: Input parameters is invalid, use '--help' to view the help.")
  219. for opt_name, opt_value in opts:
  220. if opt_name in ('-m', '--model'):
  221. model = opt_value
  222. print("INFO: Input model file is", model)
  223. convert_graphs(model)
  224. elif opt_name in ('-h', '--help'):
  225. usage()
  226. break
  227. elif opt_name in ('-v', '--version'):
  228. print("version 1.0.0")
  229. break
  230. if len(sys.argv) == 1:
  231. print("INFO: Please specify the input parameters, and use '--help' to view the help.")