You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ConcreteFunction.cs 5.7 kB

4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow.Framework.Models;
  5. using Tensorflow.Graphs;
  6. using static Tensorflow.Binding;
  7. namespace Tensorflow.Functions
  8. {
  9. /// <summary>
  10. ///
  11. /// </summary>
  12. public class ConcreteFunction : IDisposable
  13. {
  14. IntPtr _handle;
  15. FuncGraph func_graph;
  16. public Tensor[] CapturedInputs => func_graph.external_captures;
  17. public string Name
  18. {
  19. get
  20. {
  21. if (func_graph != null)
  22. return func_graph.FuncName;
  23. return _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
  24. }
  25. }
  26. public Tensor[] Outputs;
  27. public Type ReturnType;
  28. public TensorSpec[] OutputStructure;
  29. public ConcreteFunction(string name)
  30. {
  31. func_graph = new FuncGraph(name);
  32. func_graph.as_default();
  33. }
  34. public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null)
  35. {
  36. func_graph = graph;
  37. ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray());
  38. }
  39. public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
  40. {
  41. string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
  42. // IntPtr func_handle;
  43. using var graph = new FuncGraph(func_name);
  44. graph.as_default();
  45. var input = tf.placeholder(dtype);
  46. var output = func(input);
  47. var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
  48. _handle = graph.ToGraph(opers,
  49. new[] { input },
  50. new[] { output },
  51. null);
  52. }
  53. public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
  54. {
  55. string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
  56. // IntPtr func_handle;
  57. using var graph = new FuncGraph(func_name);
  58. graph.as_default();
  59. var input = tf.placeholder(dtype);
  60. var output = func(input);
  61. OutputStructure = output.structure;
  62. var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
  63. _handle = graph.ToGraph(opers,
  64. new[] { input },
  65. new[] { output.variant_tensor },
  66. null);
  67. }
  68. public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
  69. TF_DataType[] dtypes, TensorShape[] shapes)
  70. {
  71. string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
  72. // IntPtr func_handle;
  73. using var graph = new FuncGraph(func_name);
  74. graph.as_default();
  75. var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args");
  76. var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args");
  77. var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args");
  78. var outputs = func(input1, (input2, input3));
  79. Outputs = new[] { outputs.Item1, outputs.Item2 };
  80. OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() };
  81. var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
  82. _handle = graph.ToGraph(opers,
  83. new[] { input1, input2, input3 },
  84. new[] { outputs.Item1, outputs.Item2 },
  85. null);
  86. }
  87. public void ToGraph(Tensors inputs, Tensors outputs)
  88. {
  89. var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
  90. _handle = func_graph.ToGraph(opers,
  91. inputs,
  92. outputs,
  93. null);
  94. OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray();
  95. }
  96. public Tensors Invoke(Tensors inputs)
  97. {
  98. var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly());
  99. var (forward_function, args_with_tangents) = forward_backward.Forward();
  100. Tensors flat_outputs = null;
  101. if (tf.Context.executing_eagerly())
  102. flat_outputs = forward_function.Call(args_with_tangents);
  103. forward_backward.Record(flat_outputs);
  104. return flat_outputs;
  105. }
  106. public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs)
  107. {
  108. var new_args = new List<Tensor>();
  109. new_args.AddRange(args);
  110. new_args.AddRange(captured_inputs);
  111. args = new_args.ToArray();
  112. var attrs = new object[]
  113. {
  114. "executor_type", "",
  115. "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
  116. };
  117. return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
  118. }
  119. ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
  120. {
  121. var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
  122. return new ForwardBackwardCall(functions, args, tape_watching: true);
  123. }
  124. public override string ToString()
  125. => Name;
  126. public void Dispose()
  127. {
  128. c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle);
  129. c_api.TF_DeleteFunction(_handle);
  130. }
  131. }
  132. }