using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Framework.Models;
using Tensorflow.Graphs;
using Tensorflow.Train;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
{
    /// <summary>
    /// A callable wrapper around a traced <see cref="FuncGraph"/>: it owns the graph, the
    /// forward (inference) function derived from it, and the gradient-function machinery
    /// (<see cref="DelayedRewriteGradientFunctions"/>) used by <see cref="CallFlat"/> when a
    /// gradient tape may be recording.
    /// </summary>
    public class ConcreteFunction: Trackable
    {
        protected IEnumerable<Tensor> _captured_inputs;
        internal FuncGraph func_graph;
        protected DelayedRewriteGradientFunctions _delayed_rewrite_functions;
        protected Dictionary<string, AttrValue> _attrs;
        protected FunctionSpec _function_spec;
        protected FunctionSpec _pre_initialized_function_spec = null;
        protected EagerDefinedFunction _inference_function;
        internal ForwardBackwardCall forward_backward;

        // Built lazily and then reused, so the first-order tape gradient functions are
        // constructed at most once per ConcreteFunction even though a ForwardBackwardCall
        // is selected on every call.
        private FirstOrderTapeGradientFunctions _first_order_tape_functions;

        /// <summary>Inputs of the traced graph.</summary>
        public Tensor[] Inputs => func_graph.Inputs;

        /// <summary>The traced graph's external captures.</summary>
        public Tensor[] CapturedInputs => func_graph.external_captures;

        /// <summary>Name of the forward function produced by the delayed-rewrite gradient functions.</summary>
        public string Name => _delayed_rewrite_functions.Forward().Name;

        public Tensor[] Outputs;
        public Type ReturnType;

        /// <summary>
        /// Specs of the traced outputs; set by <see cref="ToGraph"/> and by the
        /// dataset-returning constructor.
        /// </summary>
        public TensorSpec[] OutputStructure;
        public IEnumerable<string> ArgKeywords { get; set; }
        public long NumPositionArgs { get; set; }

        /// <summary>
        /// Creates a function over a new, empty <see cref="FuncGraph"/> named <paramref name="name"/>.
        /// The graph is presumably populated afterwards by the caller (e.g. via
        /// <see cref="Enter"/> / <see cref="ToGraph"/> / <see cref="Exit"/>).
        /// </summary>
        /// <param name="name">Name of the new function graph.</param>
        public ConcreteFunction(string name)
        {
            func_graph = new FuncGraph(name);
            _initialize_functions();
        }

        /// <summary>Wraps an already-traced graph.</summary>
        /// <param name="graph">The traced function graph.</param>
        /// <param name="attrs">Function attributes; <c>null</c> is treated as an empty set.</param>
        public ConcreteFunction(FuncGraph graph, Dictionary<string, AttrValue> attrs = null)
        {
            func_graph = graph;
            _initialize_functions(attrs);
        }

        /// <summary>
        /// Traces <paramref name="func"/> applied to a single placeholder of type
        /// <paramref name="dtype"/> into a new graph named after the delegate's method.
        /// </summary>
        /// <param name="func">Function to trace; its single output becomes the graph output.</param>
        /// <param name="dtype">Data type of the input placeholder.</param>
        public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
        {
            func_graph = new FuncGraph($"{func.Method.Name}_{ops.uid_function()}");
            func_graph.as_default();
            try
            {
                var input = tf.placeholder(dtype);
                var output = func(input);
                func_graph.ToGraph(_graph_operations(), new[] { input }, new[] { output }, null);
            }
            finally
            {
                // Exit the graph entered by as_default() even if `func` throws.
                func_graph.Exit();
            }
            _initialize_functions();
        }

        /// <summary>
        /// Traces a dataset-producing <paramref name="func"/> applied to a single placeholder of
        /// type <paramref name="dtype"/>. The dataset's variant tensor becomes the graph output and
        /// its structure is stored in <see cref="OutputStructure"/>.
        /// </summary>
        /// <param name="func">Function to trace.</param>
        /// <param name="dtype">Data type of the input placeholder.</param>
        public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
        {
            func_graph = new FuncGraph($"{func.Method.Name}_{ops.uid_function()}");
            func_graph.as_default();
            try
            {
                var input = tf.placeholder(dtype);
                var output = func(input);
                OutputStructure = output.structure;
                func_graph.ToGraph(_graph_operations(), new[] { input }, new[] { output.variant_tensor }, null);
            }
            finally
            {
                // Exit the graph entered by as_default() even if `func` throws.
                func_graph.Exit();
            }
            _initialize_functions();
        }

        /// <summary>
        /// Finalizes the graph with the given inputs/outputs and records the output specs in
        /// <see cref="OutputStructure"/>.
        /// </summary>
        public void ToGraph(Tensors inputs, Tensors outputs)
        {
            func_graph.ToGraph(_graph_operations(), inputs, outputs, null);
            OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray();
        }

        /// <summary>Makes <see cref="func_graph"/> the default graph (calls <c>as_default()</c>).</summary>
        public void Enter()
        {
            func_graph.as_default();
        }

        /// <summary>Leaves the scope entered by <see cref="Enter"/>.</summary>
        public void Exit()
        {
            func_graph.Exit();
        }

        /// <summary>Calls the function with <paramref name="inputs"/> plus <see cref="CapturedInputs"/>.</summary>
        public Tensors FilteredCall(Tensors inputs)
        {
            return CallFlat(inputs, CapturedInputs);
        }

        /// <summary>
        /// Executes the wrapped function.
        /// </summary>
        /// <param name="args">Explicit inputs to the function.</param>
        /// <param name="captured_inputs">Captured tensors, appended after <paramref name="args"/>.</param>
        /// <returns>The function's flat outputs.</returns>
        public Tensors CallFlat(Tensor[] args, Tensor[] captured_inputs)
        {
            var executing_eagerly = tf.Context.executing_eagerly();

            // The function's inputs are the explicit args followed by the captured tensors.
            // TODO: graph-building-specific input handling (shape inference) is not implemented.
            var tensor_inputs = new Tensors();
            tensor_inputs.AddRange(args);
            tensor_inputs.AddRange(captured_inputs);
            args = tensor_inputs.ToArray();

            // Use the same constant SelectForwardAndBackwardFunctions compares against.
            var possible_gradient_type = tf.Runner.MustRecordGradient()
                ? gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER
                : 0;

            // No tape is watching; skip to running the function.
            if (possible_gradient_type == 0 && executing_eagerly)
            {
                return _build_call_outputs(_inference_function.Call(args));
            }

            // Select on every call: the ForwardBackwardCall is constructed with this call's `args`,
            // and its Forward() supplies the tensors passed to Call/Record, so reusing one built
            // for an earlier call would run against that earlier call's inputs.
            forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly);
            var (forward_function, args_with_tangents) = forward_backward.Forward();

            // TODO(Rinne): add `default_graph._override_gradient_function` for graph mode;
            // until then eager and graph mode call the forward function identically.
            Tensors flat_outputs = forward_function.Call(args_with_tangents);
            forward_backward.Record(flat_outputs);
            return _build_call_outputs(flat_outputs);
        }

        /// <summary>
        /// Adds the forward function to graph <paramref name="g"/>. When not executing eagerly and
        /// <paramref name="g"/> is null, the current default graph is used.
        /// </summary>
        public void AddTograph(Graph? g = null)
        {
            if (!tf.Context.executing_eagerly() && g is null)
            {
                g = ops.get_default_graph();
            }
            _delayed_rewrite_functions.Forward().AddToGraph(g);
        }

        /// <summary>
        /// Replaces the stored captured inputs.
        /// NOTE(review): <see cref="FilteredCall"/> passes <see cref="CapturedInputs"/> (the graph's
        /// external captures), not this field — confirm that is intended.
        /// </summary>
        public void SetExternalCaptures(IEnumerable<Tensor> captures)
        {
            _captured_inputs = captures;
        }

        /// <summary>
        /// Chooses the gradient functions and tape-recording mode for one call.
        /// </summary>
        /// <exception cref="NotImplementedException">
        /// When executing eagerly, or for higher-order gradients.
        /// </exception>
        private ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
        {
            if (executing_eagerly)
            {
                throw new NotImplementedException();
            }
            var input_tangents = new TangentInfo();

            if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
            {
                // The eager case already threw above, so only the tangent indices matter here.
                if (input_tangents.Indices is not null)
                {
                    _first_order_tape_functions ??= new FirstOrderTapeGradientFunctions(func_graph, false);
                    return new ForwardBackwardCall(_first_order_tape_functions, args, tape_watching: true);
                }
                return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: true);
            }
            if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER)
            {
                throw new NotImplementedException();
            }
            // TODO(Rinne): add arg "input_tangents" for ForwardBackwardCall.
            return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false);
        }

        /// <summary>Stores <paramref name="spec"/> and (re)initializes <see cref="_function_spec"/> from it.</summary>
        internal void _set_function_spec(FunctionSpec spec)
        {
            _function_spec = null;
            _pre_initialized_function_spec = spec;
            _initialize_function_spec();
        }

        /// <summary>
        /// Copies the pre-initialized spec into <see cref="_function_spec"/>; no-op when there is none.
        /// </summary>
        internal void _initialize_function_spec()
        {
            if (_pre_initialized_function_spec is null)
            {
                return;
            }
            Debug.Assert(_function_spec is null, "already initialized");
            var spec = _pre_initialized_function_spec;
            // TODO(Rinne): self.structured_input_signature
            _function_spec = new FunctionSpec()
            {
                Fullargspec = spec.Fullargspec,
                IsMethod = spec.IsMethod,
                InputSignature = spec.InputSignature
            };
        }

        private Tensors _build_call_outputs(Tensors result)
        {
            // TODO(Rinne): deal with `func_graph.structured_outputs`
            return result;
        }

        /// <summary>
        /// Shared constructor tail: snapshots the graph's external captures, stores the attributes
        /// (null becomes an empty dictionary), and builds the delayed-rewrite gradient functions and
        /// the forward (inference) function.
        /// </summary>
        private void _initialize_functions(Dictionary<string, AttrValue> attrs = null)
        {
            _captured_inputs = func_graph.external_captures;
            _attrs = attrs ?? new Dictionary<string, AttrValue>();
            _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs);
            _inference_function = _delayed_rewrite_functions.Forward();
        }

        /// <summary>
        /// All nodes of <see cref="func_graph"/> as <see cref="Operation"/>s (entries that are not
        /// operations become null, as with an <c>as</c> cast).
        /// </summary>
        private Operation[] _graph_operations()
            => func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();

        public override string ToString() => Name;
    }
}