|
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using Tensorflow.Framework.Models;
- using Tensorflow.Graphs;
- using static Tensorflow.Binding;
-
namespace Tensorflow.Functions
{
    /// <summary>
    /// A TensorFlow function built either from a <see cref="FuncGraph"/> or by tracing a delegate,
    /// backed by a native function handle produced by <c>FuncGraph.ToGraph</c>.
    /// </summary>
    public class ConcreteFunction : IDisposable
    {
        // Native function handle returned by FuncGraph.ToGraph; IntPtr.Zero before conversion
        // (string constructor) and after Dispose.
        IntPtr _handle;

        // Graph this function wraps. Null for the delegate-based constructors, which trace into a
        // temporary graph that is disposed (via `using var`) before the constructor returns.
        FuncGraph func_graph;

        // Ensures the native handle is released at most once.
        bool _disposed;

        /// <summary>
        /// Tensors captured from outside the function graph. Empty when this instance holds no graph
        /// (delegate-based constructors).
        /// </summary>
        public Tensor[] CapturedInputs => func_graph?.external_captures ?? Array.Empty<Tensor>();

        /// <summary>
        /// The function name: the graph's <c>FuncName</c> when a graph is held, otherwise the name read
        /// from the native handle, or an empty string when neither is available (including after Dispose).
        /// </summary>
        public string Name
        {
            get
            {
                if (func_graph != null)
                    return func_graph.FuncName;

                return _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
            }
        }

        /// <summary>Output tensors; only populated by the two-tuple delegate constructor.</summary>
        public Tensor[] Outputs;

        /// <summary>Not assigned anywhere in this class; left for callers that set it.</summary>
        public Type ReturnType;

        /// <summary>Specs describing the function outputs, set by the constructors / <see cref="ToGraph"/>.</summary>
        public TensorSpec[] OutputStructure;

        /// <summary>
        /// Creates a new function graph named <paramref name="name"/> and calls <c>as_default()</c> on it.
        /// Call <see cref="ToGraph"/> once the graph body has been built to produce the native function.
        /// </summary>
        public ConcreteFunction(string name)
        {
            func_graph = new FuncGraph(name);
            func_graph.as_default();
        }

        /// <summary>
        /// Wraps an existing graph and immediately converts it to a native function, using the graph's
        /// inputs and its non-null outputs.
        /// </summary>
        /// <param name="graph">The graph to wrap; must not be null.</param>
        /// <param name="attrs">Currently unused; kept for API compatibility.</param>
        /// <exception cref="ArgumentNullException"><paramref name="graph"/> is null.</exception>
        public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null)
        {
            func_graph = graph ?? throw new ArgumentNullException(nameof(graph));

            ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray());
        }

        /// <summary>
        /// Traces a single-tensor delegate into a temporary graph, feeding it one placeholder of
        /// <paramref name="dtype"/>, and keeps only the resulting native handle.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
        {
            using var graph = new FuncGraph(MakeFunctionName(func));
            graph.as_default();

            var input = tf.placeholder(dtype);
            var output = func(input);

            _handle = graph.ToGraph(GetOperations(graph),
                new[] { input },
                new[] { output },
                null);
        }

        /// <summary>
        /// Traces a delegate that maps a tensor to a dataset; the dataset's variant tensor becomes the
        /// function output and its <c>structure</c> is stored in <see cref="OutputStructure"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
        {
            using var graph = new FuncGraph(MakeFunctionName(func));
            graph.as_default();

            var input = tf.placeholder(dtype);
            var output = func(input);

            OutputStructure = output.structure;

            _handle = graph.ToGraph(GetOperations(graph),
                new[] { input },
                new[] { output.variant_tensor },
                null);
        }

        /// <summary>
        /// Traces a delegate taking one tensor plus a pair of tensors and returning a pair of tensors.
        /// Three placeholders are created from the first three entries of <paramref name="dtypes"/> and
        /// <paramref name="shapes"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        /// <exception cref="ArgumentException">Fewer than three dtypes or shapes were supplied.</exception>
        public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
            TF_DataType[] dtypes, TensorShape[] shapes)
        {
            var func_name = MakeFunctionName(func);
            if (dtypes == null)
                throw new ArgumentNullException(nameof(dtypes));
            if (shapes == null)
                throw new ArgumentNullException(nameof(shapes));
            // Validate before creating the graph so we fail fast instead of with IndexOutOfRangeException.
            if (dtypes.Length < 3 || shapes.Length < 3)
                throw new ArgumentException("Expected at least three dtypes and three shapes, one per placeholder.");

            using var graph = new FuncGraph(func_name);
            graph.as_default();

            var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args");
            var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args");
            var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args");
            var outputs = func(input1, (input2, input3));

            Outputs = new[] { outputs.Item1, outputs.Item2 };
            OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() };

            _handle = graph.ToGraph(GetOperations(graph),
                new[] { input1, input2, input3 },
                new[] { outputs.Item1, outputs.Item2 },
                null);
        }

        /// <summary>
        /// Converts the held graph into a native function with the given inputs and outputs and records
        /// the output specs in <see cref="OutputStructure"/>. Requires a graph-backed instance.
        /// </summary>
        public void ToGraph(Tensors inputs, Tensors outputs)
        {
            // NOTE(review): calling this twice overwrites _handle without deleting the previous native
            // function; confirm whether repeated conversion is ever expected.
            _handle = func_graph.ToGraph(GetOperations(func_graph),
                inputs,
                outputs,
                null);

            OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray();
        }

        /// <summary>
        /// Calls the function through a <c>ForwardBackwardCall</c> and records the outputs on it
        /// (presumably for gradient tape bookkeeping).
        /// </summary>
        public Tensors Invoke(Tensors inputs)
        {
            var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly());
            var (forward_function, args_with_tangents) = forward_backward.Forward();
            Tensors flat_outputs = null;
            // NOTE(review): outside eager mode flat_outputs stays null, Record(null) is still invoked and
            // null is returned; confirm graph-mode callers expect this.
            if (tf.Context.executing_eagerly())
                flat_outputs = forward_function.Call(args_with_tangents);
            forward_backward.Record(flat_outputs);
            return flat_outputs;
        }

        /// <summary>
        /// Executes the function by name via <c>tf.Runner</c>, passing <paramref name="args"/> followed by
        /// <paramref name="captured_inputs"/> as flat inputs. Requires a graph-backed instance.
        /// </summary>
        /// <exception cref="ArgumentNullException">Either array is null.</exception>
        public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs)
        {
            var flat_args = args.Concat(captured_inputs).ToArray();

            var attrs = new object[]
            {
                "executor_type", "",
                "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
            };
            return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, flat_args, attrs);
        }

        // Builds the forward/backward wrapper used by Invoke. The gradient-type and eager-mode parameters
        // are currently ignored.
        ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
        {
            var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
            return new ForwardBackwardCall(functions, args, tape_watching: true);
        }

        // Unique per-trace function name derived from the delegate's method name.
        static string MakeFunctionName(Delegate func)
        {
            if (func == null)
                throw new ArgumentNullException(nameof(func));
            return $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
        }

        // All Operations registered in the graph. OfType skips null / non-Operation entries, which the
        // previous `x as Operation` projection would have passed to ToGraph as nulls.
        static Operation[] GetOperations(FuncGraph graph)
            => graph._nodes_by_name.Values.OfType<Operation>().ToArray();

        public override string ToString()
            => Name;

        /// <summary>
        /// Removes the function from the eager context by name and deletes the native handle.
        /// Safe to call multiple times; calls after the first do nothing. The status reported by
        /// the context removal is not checked.
        /// </summary>
        public void Dispose()
        {
            if (_disposed)
                return;
            _disposed = true;

            // Name must be read before the handle is freed, since it may come from the handle.
            c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle);
            c_api.TF_DeleteFunction(_handle);
            // Drop the dangling pointer so Name/ToString can no longer dereference freed memory.
            _handle = IntPtr.Zero;
        }
    }
}
|