| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Graphs; | |||
| @@ -11,11 +12,34 @@ namespace Tensorflow.Functions | |||
| /// </summary> | |||
| public class ConcreteFunction : IDisposable | |||
| { | |||
| public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||
| IntPtr _handle; | |||
| FuncGraph func_graph; | |||
/// <summary>
/// Name of this function. When a function graph is attached its FuncName is returned;
/// otherwise the name is read from the native handle, or an empty string is returned
/// when the handle is IntPtr.Zero.
/// </summary>
public string Name
{
    get
    {
        // The attached graph, when present, takes precedence over the native handle.
        if (func_graph != null)
        {
            return func_graph.FuncName;
        }

        if (_handle == IntPtr.Zero)
        {
            return string.Empty;
        }

        return c_api.StringPiece(c_api.TF_FunctionName(_handle));
    }
}
| public Tensor[] Outputs; | |||
| public Type ReturnType; | |||
| public TensorSpec[] OutputStructure; | |||
| public ConcreteFunction(string name) | |||
| { | |||
| func_graph = new FuncGraph(name); | |||
| } | |||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs) | |||
| { | |||
| func_graph = graph; | |||
| } | |||
| public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | |||
| { | |||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
| @@ -28,8 +52,8 @@ namespace Tensorflow.Functions | |||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
| _handle = graph.ToGraph(opers, | |||
| new Operation[] { input }, | |||
| new Operation[] { output }, | |||
| new[] { input }, | |||
| new[] { output }, | |||
| null); | |||
| } | |||
| } | |||
| @@ -48,8 +72,8 @@ namespace Tensorflow.Functions | |||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
| _handle = graph.ToGraph(opers, | |||
| new Operation[] { input }, | |||
| new Operation[] { output.variant_tensor.op }, | |||
| new[] { input }, | |||
| new[] { output.variant_tensor }, | |||
| null); | |||
| } | |||
| } | |||
| @@ -72,12 +96,38 @@ namespace Tensorflow.Functions | |||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
| _handle = graph.ToGraph(opers, | |||
| new Operation[] { input1, input2, input3 }, | |||
| new Operation[] { outputs.Item1.op, outputs.Item2.op }, | |||
| new[] { input1, input2, input3 }, | |||
| new[] { outputs.Item1, outputs.Item2 }, | |||
| null); | |||
| } | |||
| } | |||
/// <summary>
/// Passes every node registered in the wrapped function graph, together with the
/// given inputs and outputs, to FuncGraph.ToGraph and stores the returned handle.
/// </summary>
/// <param name="inputs">Tensors forwarded as the function inputs.</param>
/// <param name="outputs">Tensors forwarded as the function outputs.</param>
public void ToGraph(Tensors inputs, Tensors outputs)
{
    var body_ops = func_graph._nodes_by_name.Values
        .Select(node => node as Operation)
        .ToArray();

    // No explicit output names are supplied.
    _handle = func_graph.ToGraph(body_ops, inputs, outputs, null);
}
/// <summary>
/// Runs the forward function on the given inputs and hands the outputs to the
/// forward/backward call for recording.
/// </summary>
/// <param name="inputs">Arguments passed to the function.</param>
/// <returns>
/// The outputs of the forward function when executing eagerly; null otherwise,
/// because the forward function is only called in eager mode.
/// </returns>
public Tensors Invoke(Tensors inputs)
{
    var call = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly());
    var (forward_function, args_with_tangents) = call.Forward();

    Tensors flat_outputs = tf.Context.executing_eagerly()
        ? forward_function.Call(args_with_tangents)
        : null;

    call.Record(flat_outputs);
    return flat_outputs;
}
/// <summary>
/// Builds the forward/backward call object for the wrapped function graph.
/// Always uses first-order tape gradient functions with tape watching enabled;
/// possible_gradient_type and executing_eagerly are not consulted here.
/// </summary>
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
    => new ForwardBackwardCall(
        new FirstOrderTapeGradientFunctions(func_graph, false),
        args,
        tape_watching: true);
| public void Dispose() | |||
| { | |||
| c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | |||
| @@ -0,0 +1,44 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Functions | |||
| { | |||
/// <summary>
/// A function built from a <see cref="FuncGraph"/> that is executed by name
/// through the eager runner.
/// </summary>
public class EagerDefinedFunction
{
    /// <summary>Number of outputs requested from the runner on every call.</summary>
    public int _num_outputs;

    FuncGraph _func_graph;

    /// <summary>Name of the underlying function graph.</summary>
    public string Name => _func_graph.FuncName;

    /// <summary>
    /// Creates a new function graph from <paramref name="graph"/> and converts it
    /// into a function whose body is every operation of the graph except the
    /// operations that produce <paramref name="inputs"/>.
    /// </summary>
    /// <param name="name">Name given to the new function graph.</param>
    /// <param name="graph">Graph whose operations form the function body.</param>
    /// <param name="inputs">Function input tensors.</param>
    /// <param name="outputs">Function output tensors.</param>
    /// <param name="attrs">Attributes forwarded to the new function graph.</param>
    public EagerDefinedFunction(string name, FuncGraph graph,
        Tensors inputs, Tensors outputs,
        Dictionary<string, string> attrs)
    {
        _num_outputs = outputs.Length;

        // Operations producing the inputs are excluded from the function body.
        var input_ops = inputs.Select(x => x.op).ToArray();
        var body_ops = new List<Operation>();
        foreach (var candidate in graph.get_operations())
        {
            if (input_ops.Contains(candidate.op))
            {
                continue;
            }

            body_ops.Add(candidate as Operation);
        }

        _func_graph = new FuncGraph(graph, name, attrs);
        // An empty name array is passed for the output names.
        _func_graph.ToGraph(body_ops.ToArray(), inputs, outputs, new string[0]);
    }

    /// <summary>
    /// Executes the function by name on the current context's device.
    /// </summary>
    /// <param name="args">Arguments passed to the function.</param>
    /// <returns>The tensors returned by the runner.</returns>
    public Tensors Call(Tensors args)
    {
        return tf.Runner.TFE_Execute(tf.Context,
            tf.Context.DeviceName,
            _func_graph.FuncName,
            args,
            null,
            _num_outputs);
    }
}
| } | |||
| @@ -0,0 +1,25 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Graphs; | |||
| namespace Tensorflow.Functions | |||
| { | |||
/// <summary>
/// Tape gradient functions that build the forward and backward functions from the
/// outputs of the wrapped function graph.
/// </summary>
public class FirstOrderTapeGradientFunctions : TapeGradientFunctions
{
    public FirstOrderTapeGradientFunctions(FuncGraph func_graph, bool need_gradients_for_jvps)
        : base(func_graph, need_gradients_for_jvps)
    {
    }

    /// <summary>
    /// Builds the forward/backward functions for the graph outputs, stores the
    /// results on this instance and returns the forward function.
    /// </summary>
    /// <param name="inference_args">Arguments forwarded to the builder.</param>
    /// <returns>The newly built forward function.</returns>
    public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
    {
        var built = BuildFunctionsForOutputs(_func_graph.Outputs, inference_args);

        _forward = built.Item1;
        _forward_graph = built.Item2;
        _backward = built.Item3;
        _forwardprop_output_indices = built.Item4;
        _num_forwardprop_outputs = built.Item5;

        return _forward;
    }
}
| } | |||
| @@ -0,0 +1,38 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| /// <summary> | |||
| /// Holds the state of a function call between execution and recording. | |||
| /// </summary> | |||
/// <summary>
/// Holds the state of a function call between execution and recording.
/// </summary>
public class ForwardBackwardCall
{
    TapeGradientFunctions _functions;
    Tensors _inference_args;
    // Declared but never assigned or read in this class.
    Tensors _input_tangents;
    bool _tape_watching;

    /// <param name="functions">Provider of the forward function and the recorder.</param>
    /// <param name="inference_args">Arguments the function is called with.</param>
    /// <param name="tape_watching">When false, <see cref="Record"/> does nothing.</param>
    public ForwardBackwardCall(TapeGradientFunctions functions,
        Tensors inference_args,
        bool tape_watching)
    {
        _tape_watching = tape_watching;
        _inference_args = inference_args;
        _functions = functions;
    }

    /// <summary>
    /// Returns the forward function together with the arguments it should be called with.
    /// </summary>
    public (EagerDefinedFunction, Tensors) Forward()
        => (_functions.Forward(_inference_args), _inference_args);

    /// <summary>
    /// Records the call, but only when the tape is watching and outputs were produced.
    /// </summary>
    /// <param name="flat_outputs">Outputs of the forward call; may be null.</param>
    public void Record(Tensors flat_outputs)
    {
        if (!_tape_watching || flat_outputs == null)
        {
            return;
        }

        _functions.Record(flat_outputs, _inference_args);
    }
}
| } | |||
| @@ -0,0 +1,120 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.tensorflow; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| /// <summary> | |||
| /// Caches forward and backward functions compatible with eager gradients. | |||
| /// </summary> | |||
/// <summary>
/// Caches forward and backward functions compatible with eager gradients.
/// </summary>
public abstract class TapeGradientFunctions
{
    // Attribute keys and name prefixes are fixed values, so they are constants
    // rather than mutable per-instance fields.
    const string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
    const string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";
    const string _FORWARD_PREFIX = "__forward_";
    const string _BACKWARD_PREFIX = "__backward_";
    const string _INFERENCE_PREFIX = "__inference_";

    protected FuncGraph _func_graph;
    protected EagerDefinedFunction _forward;
    protected FuncGraph _forward_graph;
    protected List<int> _forwardprop_output_indices;
    protected int _num_forwardprop_outputs;
    protected ConcreteFunction _backward;

    /// <param name="func_graph">Function graph the forward/backward functions are built from.</param>
    /// <param name="need_gradients_for_jvps">Currently not used by this class.</param>
    public TapeGradientFunctions(FuncGraph func_graph,
        bool need_gradients_for_jvps)
    {
        _func_graph = func_graph;
    }

    /// <summary>
    /// Returns the forward function for the given arguments.
    /// </summary>
    public EagerDefinedFunction Forward(Tensors inference_args)
    {
        return ForwardAndBackwardFunctions(inference_args);
    }

    /// <summary>
    /// Record the function call operation.
    /// </summary>
    /// <param name="flat_outputs">Outputs of the forward function call.</param>
    /// <param name="inference_args">Arguments the forward function was called with.</param>
    public void Record(Tensors flat_outputs, Tensors inference_args)
    {
        // Only the wrapped backward function is needed here; the second tuple
        // element is intentionally discarded.
        var (backward_function, _) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
        tf.Runner.RecordGradient(_forward.Name, flat_outputs, new object[0], inference_args,
            getBackwardFunction: () => backward_function);
    }

    /// <summary>
    /// Wraps the backward function for recording on the tape.
    /// </summary>
    /// <returns>The wrapper delegate paired with <paramref name="flat_outputs"/>.</returns>
    (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs)
    {
        // TODO: placeholder - the wrapper does not call `backward` yet and always
        // yields an empty gradient array.
        BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
        {
            return new Tensor[0];
        };

        return (_backward_function_wrapper, flat_outputs);
    }

    /// <summary>
    /// Builds the backward graph for the trainable outputs and wraps the function
    /// graph into a forward function.
    /// </summary>
    /// <param name="outputs">Candidate outputs; only trainable ones receive gradient placeholders.</param>
    /// <param name="inference_args">Currently not used by this method.</param>
    /// <returns>
    /// The forward function, the function graph, the backward function, and two
    /// forward-prop values that are always null and 0.
    /// </returns>
    protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
        BuildFunctionsForOutputs(Tensors outputs, Tensors inference_args)
    {
        var trainable_outputs = new List<Tensor>();
        foreach (var output in outputs)
        {
            if (gradients_util.IsTrainable(output))
                trainable_outputs.Add(output);
        }

        var gradients_wrt_outputs = new List<Tensor>();
        // NOTE(review): constructing the FuncGraph presumably switches the context
        // mode; the restore_mode call below pairs with it - confirm against FuncGraph.
        var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}");

        Tensors gradients_wrt_inputs;
        try
        {
            foreach (var output in trainable_outputs)
                gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));

            gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
                _func_graph.Inputs,
                grad_ys: gradients_wrt_outputs.ToArray(),
                src_graph: _func_graph);
        }
        finally
        {
            // Restore the context mode even when building the gradients throws.
            tf.Context.restore_mode();
        }

        var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}";
        var backward_function_attr = new Dictionary<string, string>();
        backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
        backwards_graph.Inputs = gradients_wrt_outputs;
        backwards_graph.Outputs = gradients_wrt_inputs;

        var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);

        var forward_function_attr = new Dictionary<string, string>();
        forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name;
        var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,
            _func_graph.Inputs, _func_graph.Outputs, forward_function_attr);

        return (forward_function, _func_graph, backward_function, null, 0);
    }

    /// <summary>
    /// Builds and returns the forward function. The base implementation always
    /// throws; derived classes are expected to override it.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown by the base implementation.</exception>
    public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
    {
        throw new NotImplementedException("");
    }
}
| } | |||
| @@ -47,6 +47,9 @@ namespace Tensorflow | |||
| string description, | |||
| SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_FunctionName(IntPtr func); | |||
| @@ -13,8 +13,6 @@ namespace Tensorflow.Gradients | |||
| void RecordOperation(string op_type, | |||
| Tensor[] input_tensors, | |||
| TapeTensor[] output_tensors, | |||
| long[] input_tensor_id, | |||
| TF_DataType[] input_dtypes, | |||
| Func<BackwardFunction> backward_function_getter); | |||
| void VariableAccessed(ResourceVariable variable); | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.tensorflow; | |||
| using static Tensorflow.Binding; | |||
| using System.Linq; | |||
| namespace Tensorflow.Gradients | |||
| { | |||
| @@ -14,18 +15,19 @@ namespace Tensorflow.Gradients | |||
| public void RecordOperation(string op_type, | |||
| Tensor[] input_tensors, | |||
| TapeTensor[] output_tensors, | |||
| long[] input_tensor_id, | |||
| TF_DataType[] input_dtypes, | |||
| Func<BackwardFunction> backward_function_getter) | |||
| { | |||
| if (!ShouldRecord(input_tensor_id, input_dtypes)) | |||
| var input_ids = input_tensors.Select(x => x.Id).ToArray(); | |||
| var input_dtypes = input_tensors.Select(x => x.dtype).ToArray(); | |||
| if (!ShouldRecord(input_ids, input_dtypes)) | |||
| { | |||
| return; | |||
| } | |||
| long op_id = next_op_id_++; | |||
| var ids = new List<long>(input_tensor_id.Length); | |||
| foreach (var i in input_tensor_id) | |||
| var ids = new List<long>(input_ids.Length); | |||
| foreach (var i in input_ids) | |||
| { | |||
| tensor_usage_[i]++; | |||
| ids.Add(i); | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Operations.ControlFlows; | |||
| using static Tensorflow.Binding; | |||
| @@ -39,7 +40,14 @@ namespace Tensorflow | |||
| // If src_graph is a _FuncGraph (i.e. a function body), gather it and all | |||
| // ancestor graphs. This is necessary for correctly handling captured values. | |||
| var func_graphs = new List<FuncGraph>(); | |||
| var curr_graph = src_graph; | |||
| if (src_graph is FuncGraph func_graph) | |||
| { | |||
| func_graphs.append(func_graph); | |||
| curr_graph = func_graph.OuterGraph; | |||
| } | |||
| if (stop_gradients == null) | |||
| stop_gradients = new Tensor[0]; | |||
| @@ -84,7 +92,7 @@ namespace Tensorflow | |||
| var to_ops = ys.Select(x => x.op).ToList(); | |||
| var from_ops = xs.Select(x => x.op).ToList(); | |||
| var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); | |||
| (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs); | |||
| (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs , xs); | |||
| // Add the initial gradients for the ys. | |||
| foreach (var (y, grad_y) in zip(ys, grad_ys)) | |||
| @@ -258,11 +266,8 @@ namespace Tensorflow | |||
| { | |||
| var new_grad_ys = new List<Tensor>(); | |||
| for (int i = 0; i < grad_ys.Length; i++) | |||
| foreach(var (i, (y, grad_y)) in enumerate(zip(ys, grad_ys))) | |||
| { | |||
| var grad_y = grad_ys[i]; | |||
| var y = ys[i]; | |||
| _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops); | |||
| if (grad_y == null) | |||
| @@ -272,8 +277,17 @@ namespace Tensorflow | |||
| var shape = array_ops.shape(y); | |||
| var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); | |||
| var fill = gen_array_ops.fill(shape, constant); | |||
| new_grad_ys.Add(fill); | |||
| new_grad_ys.append(fill); | |||
| continue; | |||
| } | |||
| if (y.dtype.is_floating() || y.dtype.is_integer()) | |||
| { | |||
| } | |||
| // Create a grad_y tensor in the name scope of the gradient. | |||
| new_grad_ys.append(array_ops.identity(grad_y, name: $"grad_ys_{i}")); | |||
| } | |||
| return new_grad_ys.ToArray(); | |||
| @@ -294,7 +308,11 @@ namespace Tensorflow | |||
| /// <param name="colocate_gradients_with_ops"></param> | |||
| /// <param name="func_graphs"></param> | |||
| /// <param name="xs"></param> | |||
| private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs) | |||
| private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, | |||
| List<Operation> from_ops, | |||
| bool colocate_gradients_with_ops, | |||
| List<FuncGraph> func_graphs, | |||
| Tensor[] xs) | |||
| { | |||
| // Mark reachable ops from from_ops. | |||
| var reached_ops = new List<Operation>(); | |||
| @@ -511,7 +529,7 @@ namespace Tensorflow | |||
| /// <param name="from_ops"></param> | |||
| /// <param name="reached_ops"></param> | |||
| /// <param name="func_graphs"></param> | |||
| private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<object> func_graphs) | |||
| private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<FuncGraph> func_graphs) | |||
| { | |||
| Queue<Operation> queue = new Queue<Operation>(from_ops); | |||
| while (queue.Count > 0) | |||
| @@ -538,7 +556,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="t"></param> | |||
| /// <param name="func_graphs"></param> | |||
| private static Operation[] _Consumers(Tensor t, List<object> func_graphs) | |||
| private static Operation[] _Consumers(Tensor t, List<FuncGraph> func_graphs) | |||
| { | |||
| return t.consumers(); | |||
| } | |||
| @@ -647,7 +665,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private static bool IsTrainable(Tensor tensor) | |||
| public static bool IsTrainable(Tensor tensor) | |||
| { | |||
| var dtype = tensor.dtype.as_base_dtype(); | |||
| return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64, | |||
| @@ -18,10 +18,11 @@ namespace Tensorflow.Graphs | |||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
| var func_handle = graph.ToGraph(opers, | |||
| new Operation[] { input }, | |||
| new Operation[] { output }, | |||
| new[] { input }, | |||
| new[] { output }, | |||
| null); | |||
| } | |||
| return (Tensor input) => | |||
| { | |||
| @@ -48,11 +49,11 @@ namespace Tensorflow.Graphs | |||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
| var func_handle = graph.ToGraph(opers, | |||
| new Operation[] { input1, input2 }, | |||
| new Operation[] { output }, | |||
| new[] { input1, input2 }, | |||
| new[] { output }, | |||
| null); | |||
| } | |||
| return (Tensor a, Tensor b) => | |||
| { | |||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||
| @@ -3,6 +3,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Functions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Graphs | |||
| @@ -10,10 +11,10 @@ namespace Tensorflow.Graphs | |||
| [AllowChangingInputArguments] | |||
| public sealed class AutoGraphAttribute : OnMethodBoundaryAspect | |||
| { | |||
| FuncGraph graph; | |||
| ConcreteFunction function; | |||
| Tensors originalInputs; | |||
| string func_name; | |||
| static Dictionary<string, Func<Tensors, Tensors>> functions = new Dictionary<string, Func<Tensors, Tensors>>(); | |||
| static Dictionary<string, ConcreteFunction> functions = new Dictionary<string, ConcreteFunction>(); | |||
| public override void OnEntry(MethodExecutionArgs args) | |||
| { | |||
| @@ -21,22 +22,24 @@ namespace Tensorflow.Graphs | |||
| if (functions.ContainsKey(func_name)) | |||
| { | |||
| function = functions[func_name]; | |||
| if (args.Arguments[0] is Tensors tensor_inputs) | |||
| args.ReturnValue = functions[func_name](tensor_inputs.ToArray()); | |||
| args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs)); | |||
| else | |||
| args.ReturnValue = functions[func_name](args.Arguments.Select(x => x as Tensor).ToArray()); | |||
| args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray())); | |||
| args.FlowBehavior = FlowBehavior.Return; | |||
| return; | |||
| } | |||
| // make function as an Operation by autograph | |||
| graph = new FuncGraph(func_name); | |||
| // need to restore mode when exits | |||
| function = new ConcreteFunction(func_name); | |||
| // convert to Tensors | |||
| if (args.Arguments[0] is Tensors inputs) | |||
| { | |||
| originalInputs = inputs; | |||
| var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape)).ToArray(); | |||
| var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape, name: "inputs")).ToArray(); | |||
| args.Arguments[0] = new Tensors(new_inputs); | |||
| } | |||
| else | |||
| @@ -48,7 +51,7 @@ namespace Tensorflow.Graphs | |||
| if (args.Arguments[i] is EagerTensor tensor) | |||
| { | |||
| originalInputs[i] = tensor; | |||
| args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape); | |||
| args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs"); | |||
| } | |||
| } | |||
| } | |||
| @@ -56,58 +59,30 @@ namespace Tensorflow.Graphs | |||
| public override void OnExit(MethodExecutionArgs args) | |||
| { | |||
| var returnValue = mark_as_return(args.ReturnValue as Tensors); | |||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
| if (args.ReturnValue is Tensors outputs) | |||
| { | |||
| if (args.Arguments[0] is Tensors inputs) | |||
| { | |||
| graph.ToGraph(opers, | |||
| inputs.Select(x => x.op).ToArray(), | |||
| outputs.Select(x => x.op).ToArray(), | |||
| null); | |||
| } | |||
| function.ToGraph(inputs, outputs); | |||
| else | |||
| { | |||
| graph.ToGraph(opers, | |||
| args.Arguments.Select(x => (x as Tensor).op).ToArray(), | |||
| outputs.Select(x => x.op).ToArray(), | |||
| null); | |||
| } | |||
| function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs); | |||
| } | |||
| else | |||
| { | |||
| graph.ToGraph(opers, | |||
| args.Arguments.Select(x => (x as Tensor).op).ToArray(), | |||
| new Operation[] { (args.ReturnValue as Tensor).op }, | |||
| null); | |||
| } | |||
| graph.Dispose(); | |||
| function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); | |||
| Func<Tensors, Tensors> function = (x) => | |||
| { | |||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||
| tf.Context.DeviceName, | |||
| func_name, | |||
| x, | |||
| null, | |||
| 1); | |||
| return result[0]; | |||
| }; | |||
| // cache function. | |||
| function.ReturnType = args.ReturnValue.GetType(); | |||
| functions[func_name] = function; | |||
| // run function | |||
| args.ReturnValue = function(originalInputs); | |||
| args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs)); | |||
| } | |||
| Tensor mark_as_return(Tensor tensor) | |||
| object ConvertReturnValue(Tensors tensors) | |||
| { | |||
| return array_ops.identity(tensor); | |||
| if (function.ReturnType == typeof(Tensor)) | |||
| return (Tensor)tensors; | |||
| else | |||
| return tensors; | |||
| } | |||
| } | |||
| } | |||
| @@ -10,13 +10,18 @@ namespace Tensorflow.Graphs | |||
| /// </summary> | |||
| public class FuncGraph : Graph | |||
| { | |||
| List<Operation> inputs; | |||
| List<Operation> outputs; | |||
| Graph outer_graph; | |||
| public Graph OuterGraph => outer_graph; | |||
| string func_name; | |||
| // _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||
| IntPtr func_handle; | |||
| public string FuncName => func_name; | |||
| public Tensors Inputs { get; set; } | |||
| public Tensors Outputs { get; set; } | |||
| /// <summary> | |||
| /// Construct a new FuncGraph. | |||
| /// </summary> | |||
| @@ -29,8 +34,17 @@ namespace Tensorflow.Graphs | |||
| as_default(); | |||
| } | |||
| public FuncGraph(IntPtr handle, string name) | |||
| { | |||
| outer_graph = ops.get_default_graph(); | |||
| func_name = name; | |||
| tf.Context.graph_mode(); | |||
| as_default(); | |||
| } | |||
| public IntPtr ToGraph(Operation[] opers, | |||
| Operation[] inputs, Operation[] outputs, | |||
| Tensor[] inputs, Tensor[] outputs, | |||
| string[] output_names) | |||
| { | |||
| using var status = new Status(); | |||
| @@ -40,9 +54,9 @@ namespace Tensorflow.Graphs | |||
| opers.Length, | |||
| opers.Select(x => (IntPtr)x).ToArray(), | |||
| inputs.Length, | |||
| inputs.Select(x => new TF_Output(x, 0)).ToArray(), | |||
| inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
| outputs.Length, | |||
| outputs.Select(x => new TF_Output(x, 0)).ToArray(), | |||
| outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
| output_names == null || output_names.Length == 0 ? null : output_names, | |||
| IntPtr.Zero, | |||
| null, | |||
| @@ -57,13 +71,18 @@ namespace Tensorflow.Graphs | |||
| func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); | |||
| Inputs = inputs; | |||
| // mark_as_return | |||
| Outputs = outputs.Select(x => array_ops.identity(x)).ToArray(); | |||
| tf.Context.restore_mode(); | |||
| return func_handle; | |||
| } | |||
| protected override void DisposeManagedResources() | |||
| { | |||
| base.DisposeManagedResources(); | |||
| tf.Context.restore_mode(); | |||
| } | |||
| } | |||
| } | |||
| @@ -69,7 +69,7 @@ namespace Tensorflow | |||
| throw new ValueError($"Could not find operation \"{operName}\" inside graph \"{_graph_key}\"."); | |||
| var defaultKey = tf.get_default_graph().graph_key; | |||
| if (graph_key != defaultKey) | |||
| if (tf.get_default_graph().GetType().Name == "Graph" && graph_key != defaultKey) | |||
| { | |||
| //Console.WriteLine($"Current graph is not default graph."); | |||
| throw new RuntimeError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}"); | |||
| @@ -218,6 +218,8 @@ namespace Tensorflow | |||
| { | |||
| case nameof(Int32): | |||
| return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); | |||
| case nameof(Int64): | |||
| return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); | |||
| default: | |||
| return null; | |||
| } | |||
| @@ -235,6 +235,14 @@ namespace Tensorflow | |||
| } | |||
| var _op = tf.OpDefLib._apply_op_helper("Identity", name, new { input }); | |||
| if (tf.Runner.MustRecordGradient()) | |||
| { | |||
| tf.Runner.RecordGradient("Identity", _op.inputs, new object[] | |||
| { | |||
| "T", _op.get_attr<TF_DataType>("T") | |||
| }, _op.outputs); | |||
| } | |||
| return _op.output; | |||
| } | |||
| @@ -632,8 +640,8 @@ namespace Tensorflow | |||
| public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, | |||
| int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, | |||
| int shrink_axis_mask = 0, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new | |||
| => tf.Context.RunInAutoMode2( | |||
| () => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new | |||
| { | |||
| shape, | |||
| begin, | |||
| @@ -645,8 +653,8 @@ namespace Tensorflow | |||
| ellipsis_mask, | |||
| new_axis_mask, | |||
| shrink_axis_mask | |||
| }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| }).output, | |||
| () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "StridedSliceGrad", name, | |||
| null, | |||
| shape, begin, end, strides, dy, | |||
| @@ -654,8 +662,22 @@ namespace Tensorflow | |||
| "end_mask", end_mask, | |||
| "ellipsis_mask", ellipsis_mask, | |||
| "new_axis_mask", new_axis_mask, | |||
| "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
| shape, begin, end, strides, dy); | |||
| "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
| (op) => | |||
| { | |||
| var attrs = new object[] | |||
| { | |||
| "T", op.get_attr<TF_DataType>("T"), | |||
| "Index", op.get_attr<TF_DataType>("Index"), | |||
| "begin_mask", op.get_attr<long>("begin_mask"), | |||
| "end_mask", op.get_attr<long>("end_mask"), | |||
| "ellipsis_mask", op.get_attr<long>("ellipsis_mask"), | |||
| "new_axis_mask", op.get_attr<long>("new_axis_mask"), | |||
| "shrink_axis_mask", op.get_attr<long>("shrink_axis_mask") | |||
| }; | |||
| tf.Runner.RecordGradient("StridedSliceGrad", op.inputs, attrs, op.outputs); | |||
| }, | |||
| new Tensors(shape, begin, end, strides, dy)); | |||
| /// <summary> | |||
| /// Removes dimensions of size 1 from the shape of a tensor. | |||
| @@ -23,6 +23,8 @@ using Tensorflow.Gradients; | |||
| namespace Tensorflow | |||
| { | |||
| public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); | |||
| public partial class tensorflow : ITensorFlowObject | |||
| { | |||
| public TF_DataType byte8 = TF_DataType.TF_UINT8; | |||
| @@ -37,8 +39,6 @@ namespace Tensorflow | |||
| public TF_DataType chars = TF_DataType.TF_STRING; | |||
| public TF_DataType @string = TF_DataType.TF_STRING; | |||
| public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); | |||
| public Status Status; | |||
| public OpDefLibrary OpDefLib; | |||
| public Context Context; | |||
| @@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Engine | |||
| // Record the gradient because custom-made ops don't go through the | |||
| // code-gen'd eager call path | |||
| var op_type = op.node_def.Name; | |||
| var op_type = op.node_def.Op; | |||
| tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | |||
| @@ -1,9 +1,9 @@ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using static Tensorflow.KerasApi; | |||
| using static Tensorflow.Binding; | |||
| using System.Collections.Generic; | |||
| using System; | |||
| using System.Linq; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -26,9 +26,20 @@ namespace Tensorflow.Keras.Layers | |||
| var result = array_ops.reshape(inputs, shape.ToArray()); | |||
| if (!tf.Context.executing_eagerly()) | |||
| // result = result.set_shape(compute_output_shape(inputs.shape)); | |||
| throw new NotImplementedException(""); | |||
| result.set_shape(compute_output_shape(inputs.shape)); | |||
| return result; | |||
| } | |||
/// <summary>
/// Computes the output shape by concatenating the leading input dimension with the
/// configured target shape. Only inputs whose leading dimension is -1 are supported.
/// </summary>
/// <param name="input_shape">Shape of the layer input.</param>
/// <exception cref="NotImplementedException">Thrown when the leading dimension is not -1.</exception>
TensorShape compute_output_shape(TensorShape input_shape)
{
    if (input_shape.dims[0] != -1)
        throw new NotImplementedException("");

    // Keep only the leading dimension, then append the target dimensions.
    input_shape = input_shape.dims[0];
    return input_shape.concatenate(args.TargetShape.dims);
}
| } | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Tensorflow; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.ManagedAPI | |||
| @@ -36,7 +37,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
| /// <param name="a"></param> | |||
| /// <param name="b"></param> | |||
| /// <returns></returns> | |||
| // [AutoGraph] | |||
| [AutoGraph] | |||
| Tensor Mul(Tensor a, Tensor b) | |||
| { | |||
| return a * b; | |||