You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

EagerDefinedFunction.cs 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. using Google.Protobuf;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Text;
  7. using Tensorflow.Contexts;
  8. using Tensorflow.Eager;
  9. using Tensorflow.Graphs;
  10. using Tensorflow.Operations;
  11. using Tensorflow.Util;
  12. using Tensorflow.Common.Extensions;
  13. using static Tensorflow.Binding;
  14. using Tensorflow.Framework;
  15. using System.Buffers;
  16. using Tensorflow.Gradients;
  17. namespace Tensorflow.Functions
  18. {
  /// <summary>
  /// A function built from a <see cref="FuncGraph"/>: the graph is converted into a native TF
  /// function with <c>TF_GraphToFunction</c> and registered on <c>tf.Context</c>.
  /// Disposing it removes the function from the context again.
  /// </summary>
  public class EagerDefinedFunction : IDisposable
  {
      public int _num_outputs;
      FuncGraph _graph;
      FunctionDef _definition;
      OpDef _signature;
      string _name;
      internal ScopedTFFunction _c_func;
      internal Tensor[] _func_graph_outputs;
      internal string _grad_func_name;
      internal Func<Operation, Tensor[], Tensor[]> csharp_grad_func;
      internal EagerDefinedFunction _grad_func;
      // True while the function is registered on tf.Context; Dispose uses it so removal happens at most once.
      internal bool _registered_on_context = false;

      /// <summary>The function name as reported by its generated signature.</summary>
      public string Name => _name;
      public DataType[] OutputTypes { get; protected set; }
      public Shape[] OutputShapes { get; protected set; }

      /// <summary>The <see cref="FunctionDef"/> of the native function, fetched on first access and cached.</summary>
      public FunctionDef Definition => _definition ??= _get_definition();

      /// <summary>The <see cref="OpDef"/> signature of <see cref="Definition"/>, cached on first access.</summary>
      public OpDef Signature => _signature ??= Definition.Signature;

      /// <summary>
      /// Converts <paramref name="graph"/> into a native function and registers it on the eager context.
      /// </summary>
      /// <param name="name">Requested function name passed to <c>TF_GraphToFunction</c>.</param>
      /// <param name="graph">The graph holding the function body.</param>
      /// <param name="inputs">Graph tensors that become the function's arguments.</param>
      /// <param name="outputs">Graph tensors that become the function's results.</param>
      /// <param name="attrs">Attributes set on the function; <c>null</c> is treated as no attributes.</param>
      public unsafe EagerDefinedFunction(string name, FuncGraph graph,
          Tensors inputs, Tensors outputs,
          Dictionary<string, AttrValue> attrs)
      {
          var input_ops = inputs.Select(x => x.op).ToArray();
          // The ops producing the inputs become function arguments, so they are excluded from the body.
          var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op))
              .Select(x => x as Operation).ToArray();
          var output_names = _get_output_names(graph, outputs);

          Status status = new Status();
          var fn = c_api.TF_GraphToFunction(graph.c_graph,
              name,
              false,
              operations.Length,
              operations.Length == 0 ? new IntPtr[0] : operations.Select(x => (IntPtr)x).ToArray(),
              inputs.Length,
              inputs.Select(t => t._as_tf_output()).ToArray(),
              outputs.Length,
              outputs.Select(t => t._as_tf_output()).ToArray(),
              // Explicit output names are only passed when one unique name exists per output.
              output_names.Length != outputs.Length ? null : output_names,
              IntPtr.Zero, // warning: the control outputs have been totally ignored.
              null,
              status);
          status.Check(true);
          _c_func = new ScopedTFFunction(fn, name);

          if (attrs is not null)
          {
              foreach (var (attr_name, attr_value) in attrs)
              {
                  var serialized = attr_value.ToByteArray();
                  c_api.TF_FunctionSetAttrValueProto(fn, attr_name, serialized, serialized.Length, status);
                  status.Check(true);
              }
          }

          // Read after the attributes are set; the cached definition is reused by later
          // Definition/Signature accesses. The name comes from the signature, not from `name`.
          var signature = Signature;
          _name = signature.Name;
          tf_with(ops.init_scope(), s =>
          {
              tf.Context.add_function(fn);
              _registered_on_context = true;
          });
          _num_outputs = signature.OutputArg.Count;
          OutputTypes = signature.OutputArg.Select(x => x.Type).ToArray();
          OutputShapes = outputs.Select(x => x.shape).ToArray();
          _func_graph_outputs = new List<Tensor>(outputs).ToArray();
          csharp_grad_func = null;
          _graph = graph;
      }

      /// <summary>
      /// Calls the function on <paramref name="args"/>.
      /// In eager mode the function is executed directly. Otherwise a partitioned call is built,
      /// with recording paused on the top tape of the tape set.
      /// </summary>
      /// <param name="args">The argument tensors.</param>
      /// <returns>The output tensors. In graph mode they carry the shapes recorded at construction.</returns>
      public unsafe Tensors Call(Tensors args)
      {
          // TODO(Rinne): Add arg `CancellationManager`.
          // TODO(Rinne): Check the arg length.
          var function_call_options = tf.Context.FunctionCallOptions;
          string config;
          if (function_call_options.config_proto_serialized().Length == 0)
          {
              config = function_utils.get_disabled_rewriter_config().ToString();
          }
          else
          {
              config = function_call_options.config_proto_serialized().ToString();
          }
          config = ""; // TODO(Rinne): revise it.
          string executor_type = function_call_options.ExecutorType ?? "";
          var executing_eagerly = tf.Context.executing_eagerly();
          var attrs = new object[]
          {
              "executor_type", executor_type,
              "config_proto", config
          };

          Tensor[] outputs;
          if (executing_eagerly)
          {
              outputs = execute.executes(
                  Signature.Name,
                  _num_outputs,
                  args,
                  attrs,
                  tf.Context);
          }
          else
          {
              // Pause the top tape (if any) while building the call, and always resume it,
              // even if building the call throws, so the tape is never left stopped.
              var tape_set = tf.GetTapeSet();
              var tape = tape_set.Count == 0 ? null : tape_set.Peek();
              tape?.StopRecord();
              try
              {
                  outputs = functional_ops.partitioned_call(args, this, OutputTypes,
                      executing_eagerly, config, "");
              }
              finally
              {
                  tape?.StartRecord();
              }
          }

          foreach (var (i, func_graph_output) in enumerate(_func_graph_outputs))
          {
              handle_data_util.copy_handle_data(func_graph_output, outputs[i]);
          }

          if (!executing_eagerly)
          {
              // In graph mode, apply the output shapes recorded when the function was built.
              foreach (var (i, shape) in enumerate(OutputShapes))
              {
                  outputs[i].shape = shape;
              }
          }
          return outputs;
      }

      /// <summary>
      /// Makes this function available to <paramref name="g"/>, or to the eager context when
      /// <paramref name="g"/> is null and execution is eager.
      /// </summary>
      /// <param name="g">Target graph; may be null only when executing eagerly.</param>
      /// <exception cref="ArgumentNullException">
      /// <paramref name="g"/> is null while not executing eagerly.
      /// </exception>
      public void AddToGraph(Graph g = null)
      {
          if (g is null && tf.Context.executing_eagerly())
          {
              var ctx = tf.Context;
              if (!ctx.has_function(Name))
              {
                  ctx.add_function_def(Definition);
              }
              return;
          }

          if (g is null)
          {
              throw new ArgumentNullException(nameof(g),
                  "A target graph must be provided when not executing eagerly.");
          }

          if (!g.IsFunction(Name))
          {
              g.AddFunction(this);
          }
          // Also add every function registered on this function's own graph.
          foreach (var f in _graph.Functions.Values)
          {
              if (!g.IsFunction(f.Name))
              {
                  g.AddFunction(f);
              }
          }
      }

      /// <summary>
      /// Looks up the name recorded in <c>graph._output_names</c> for each output.
      /// Returns an empty array when the map is missing, when any output has no recorded name,
      /// or when the names are not unique.
      /// </summary>
      private static string[] _get_output_names(FuncGraph graph, Tensors outputs)
      {
          var graph_output_names = graph._output_names;
          if (graph_output_names is null
              || !outputs.All(t => graph_output_names.ContainsKey(ops.tensor_id(t))))
          {
              return Array.Empty<string>();
          }

          var output_names = outputs.Select(t => graph_output_names[ops.tensor_id(t)]).ToArray();
          return output_names.Distinct().Count() == output_names.Length
              ? output_names
              : Array.Empty<string>();
      }

      /// <summary>
      /// Serializes the native function with <c>TF_FunctionToFunctionDef</c> and parses the result.
      /// </summary>
      private FunctionDef _get_definition()
      {
          var buffer = c_api_util.tf_buffer();
          Status status = new();
          c_api.TF_FunctionToFunctionDef(_c_func.Get(), buffer, status);
          status.Check(true);
          var proto_data = c_api.TF_GetBuffer(buffer);
          return FunctionDef.Parser.ParseFrom(proto_data.AsSpan<byte>());
      }

      /// <summary>
      /// Removes the function from <c>tf.Context</c> if it is still registered.
      /// Calling this more than once is safe.
      /// </summary>
      public void Dispose()
      {
          if (!_registered_on_context)
          {
              return;
          }
          tf.Context.remove_function(Name);
          _registered_on_context = false;
      }
  }
  218. }