using System; using System.Collections.Generic; using System.Linq; using Tensorflow.Framework.Models; using Tensorflow.Graphs; using static Tensorflow.Binding; namespace Tensorflow.Functions { /// /// /// public class ConcreteFunction : IDisposable { IntPtr _handle; FuncGraph func_graph; public Tensor[] CapturedInputs => func_graph.external_captures; public string Name { get { if (func_graph != null) return func_graph.FuncName; return _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); } } public Tensor[] Outputs; public Type ReturnType; public TensorSpec[] OutputStructure; public ConcreteFunction(string name) { func_graph = new FuncGraph(name); func_graph.as_default(); } public ConcreteFunction(FuncGraph graph, Dictionary attrs = null) { func_graph = graph; ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); } public ConcreteFunction(Func func, TF_DataType dtype) { string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; // IntPtr func_handle; using var graph = new FuncGraph(func_name); graph.as_default(); var input = tf.placeholder(dtype); var output = func(input); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); _handle = graph.ToGraph(opers, new[] { input }, new[] { output }, null); } public ConcreteFunction(Func func, TF_DataType dtype) { string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; // IntPtr func_handle; using var graph = new FuncGraph(func_name); graph.as_default(); var input = tf.placeholder(dtype); var output = func(input); OutputStructure = output.structure; var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); _handle = graph.ToGraph(opers, new[] { input }, new[] { output.variant_tensor }, null); } public ConcreteFunction(Func func, TF_DataType[] dtypes, TensorShape[] shapes) { string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; // IntPtr func_handle; using var graph = new FuncGraph(func_name); graph.as_default(); var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args"); var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args"); var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args"); var outputs = func(input1, (input2, input3)); Outputs = new[] { outputs.Item1, outputs.Item2 }; OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() }; var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); _handle = graph.ToGraph(opers, new[] { input1, input2, input3 }, new[] { outputs.Item1, outputs.Item2 }, null); } public void ToGraph(Tensors inputs, Tensors outputs) { var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); _handle = func_graph.ToGraph(opers, inputs, outputs, null); OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray(); } public Tensors Invoke(Tensors inputs) { var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); var (forward_function, args_with_tangents) = forward_backward.Forward(); Tensors flat_outputs = null; if (tf.Context.executing_eagerly()) flat_outputs = forward_function.Call(args_with_tangents); forward_backward.Record(flat_outputs); return flat_outputs; } public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs) { var new_args = new List(); new_args.AddRange(args); new_args.AddRange(captured_inputs); args = new_args.ToArray(); var attrs = new object[] { "executor_type", "", "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() }; return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); } ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) { var functions = new FirstOrderTapeGradientFunctions(func_graph, false); return new ForwardBackwardCall(functions, args, tape_watching: true); } public override string ToString() => Name; public void Dispose() { c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); c_api.TF_DeleteFunction(_handle); } } }