| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| @@ -11,11 +12,34 @@ namespace Tensorflow.Functions | |||||
| /// </summary> | /// </summary> | ||||
| public class ConcreteFunction : IDisposable | public class ConcreteFunction : IDisposable | ||||
| { | { | ||||
| public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||||
| IntPtr _handle; | IntPtr _handle; | ||||
| FuncGraph func_graph; | |||||
| public string Name | |||||
| { | |||||
| get | |||||
| { | |||||
| if (func_graph != null) | |||||
| return func_graph.FuncName; | |||||
| return _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||||
| } | |||||
| } | |||||
| public Tensor[] Outputs; | public Tensor[] Outputs; | ||||
| public Type ReturnType; | |||||
| public TensorSpec[] OutputStructure; | public TensorSpec[] OutputStructure; | ||||
| public ConcreteFunction(string name) | |||||
| { | |||||
| func_graph = new FuncGraph(name); | |||||
| } | |||||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs) | |||||
| { | |||||
| func_graph = graph; | |||||
| } | |||||
| public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | ||||
| { | { | ||||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | ||||
| @@ -28,8 +52,8 @@ namespace Tensorflow.Functions | |||||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
| _handle = graph.ToGraph(opers, | _handle = graph.ToGraph(opers, | ||||
| new Operation[] { input }, | |||||
| new Operation[] { output }, | |||||
| new[] { input }, | |||||
| new[] { output }, | |||||
| null); | null); | ||||
| } | } | ||||
| } | } | ||||
| @@ -48,8 +72,8 @@ namespace Tensorflow.Functions | |||||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
| _handle = graph.ToGraph(opers, | _handle = graph.ToGraph(opers, | ||||
| new Operation[] { input }, | |||||
| new Operation[] { output.variant_tensor.op }, | |||||
| new[] { input }, | |||||
| new[] { output.variant_tensor }, | |||||
| null); | null); | ||||
| } | } | ||||
| } | } | ||||
| @@ -72,12 +96,38 @@ namespace Tensorflow.Functions | |||||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
| _handle = graph.ToGraph(opers, | _handle = graph.ToGraph(opers, | ||||
| new Operation[] { input1, input2, input3 }, | |||||
| new Operation[] { outputs.Item1.op, outputs.Item2.op }, | |||||
| new[] { input1, input2, input3 }, | |||||
| new[] { outputs.Item1, outputs.Item2 }, | |||||
| null); | null); | ||||
| } | } | ||||
| } | } | ||||
| public void ToGraph(Tensors inputs, Tensors outputs) | |||||
| { | |||||
| var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
| _handle = func_graph.ToGraph(opers, | |||||
| inputs, | |||||
| outputs, | |||||
| null); | |||||
| } | |||||
| public Tensors Invoke(Tensors inputs) | |||||
| { | |||||
| var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); | |||||
| var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||||
| Tensors flat_outputs = null; | |||||
| if (tf.Context.executing_eagerly()) | |||||
| flat_outputs = forward_function.Call(args_with_tangents); | |||||
| forward_backward.Record(flat_outputs); | |||||
| return flat_outputs; | |||||
| } | |||||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||||
| { | |||||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||||
| return new ForwardBackwardCall(functions, args, tape_watching: true); | |||||
| } | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | ||||
| @@ -0,0 +1,44 @@ | |||||
| using Google.Protobuf; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Graphs; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Functions | |||||
| { | |||||
| public class EagerDefinedFunction | |||||
| { | |||||
| public int _num_outputs; | |||||
| public string Name => _func_graph.FuncName; | |||||
| FuncGraph _func_graph; | |||||
| public EagerDefinedFunction(string name, FuncGraph graph, | |||||
| Tensors inputs, Tensors outputs, | |||||
| Dictionary<string, string> attrs) | |||||
| { | |||||
| _num_outputs = outputs.Length; | |||||
| var input_ops = inputs.Select(x => x.op).ToArray(); | |||||
| var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) | |||||
| .Select(x => x as Operation).ToArray(); | |||||
| var output_names = new string[0]; | |||||
| _func_graph = new FuncGraph(graph, name, attrs); | |||||
| _func_graph.ToGraph(operations, inputs, outputs, output_names); | |||||
| } | |||||
| public Tensors Call(Tensors args) | |||||
| { | |||||
| var results = tf.Runner.TFE_Execute(tf.Context, | |||||
| tf.Context.DeviceName, | |||||
| _func_graph.FuncName, | |||||
| args, | |||||
| null, | |||||
| _num_outputs); | |||||
| return results; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,25 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Graphs; | |||||
| namespace Tensorflow.Functions | |||||
| { | |||||
| public class FirstOrderTapeGradientFunctions : TapeGradientFunctions | |||||
| { | |||||
| public FirstOrderTapeGradientFunctions(FuncGraph func_graph, | |||||
| bool need_gradients_for_jvps) : base(func_graph, | |||||
| need_gradients_for_jvps) | |||||
| { | |||||
| } | |||||
| public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||||
| { | |||||
| var outputs = _func_graph.Outputs; | |||||
| (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) | |||||
| = BuildFunctionsForOutputs(outputs, inference_args); | |||||
| return _forward; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,38 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Functions | |||||
| { | |||||
| /// <summary> | |||||
| /// Holds the state of a function call between execution and recording. | |||||
| /// </summary> | |||||
| public class ForwardBackwardCall | |||||
| { | |||||
| TapeGradientFunctions _functions; | |||||
| Tensors _inference_args; | |||||
| Tensors _input_tangents; | |||||
| bool _tape_watching; | |||||
| public ForwardBackwardCall(TapeGradientFunctions functions, | |||||
| Tensors inference_args, | |||||
| bool tape_watching) | |||||
| { | |||||
| _functions = functions; | |||||
| _inference_args = inference_args; | |||||
| _tape_watching = tape_watching; | |||||
| } | |||||
| public (EagerDefinedFunction, Tensors) Forward() | |||||
| { | |||||
| var forward_function = _functions.Forward(_inference_args); | |||||
| return (forward_function, _inference_args); | |||||
| } | |||||
| public void Record(Tensors flat_outputs) | |||||
| { | |||||
| if (_tape_watching && flat_outputs != null) | |||||
| _functions.Record(flat_outputs, _inference_args); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,120 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Graphs; | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.tensorflow; | |||||
| namespace Tensorflow.Functions | |||||
| { | |||||
| /// <summary> | |||||
| /// Caches forward and backward functions compatible with eager gradients. | |||||
| /// </summary> | |||||
| public abstract class TapeGradientFunctions | |||||
| { | |||||
| string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||||
| string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||||
| string _FORWARD_PREFIX = "__forward_"; | |||||
| string _BACKWARD_PREFIX = "__backward_"; | |||||
| string _INFERENCE_PREFIX = "__inference_"; | |||||
| protected FuncGraph _func_graph; | |||||
| protected EagerDefinedFunction _forward; | |||||
| protected FuncGraph _forward_graph; | |||||
| protected List<int> _forwardprop_output_indices; | |||||
| protected int _num_forwardprop_outputs; | |||||
| protected ConcreteFunction _backward; | |||||
| public TapeGradientFunctions(FuncGraph func_graph, | |||||
| bool need_gradients_for_jvps) | |||||
| { | |||||
| _func_graph = func_graph; | |||||
| } | |||||
| public EagerDefinedFunction Forward(Tensors inference_args) | |||||
| { | |||||
| return ForwardAndBackwardFunctions(inference_args); | |||||
| } | |||||
| /// <summary> | |||||
| /// Record the function call operation. | |||||
| /// </summary> | |||||
| /// <param name="flat_outputs"></param> | |||||
| /// <param name="inference_args"></param> | |||||
| public void Record(Tensors flat_outputs, Tensors inference_args) | |||||
| { | |||||
| var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | |||||
| tf.Runner.RecordGradient(_forward.Name, flat_outputs, new object[0], inference_args, | |||||
| getBackwardFunction: () => backward_function); | |||||
| } | |||||
| (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs) | |||||
| { | |||||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||||
| { | |||||
| return new Tensor[0]; | |||||
| /*var gradients = ops.gradientFunctions[op_name](new EagerOperation | |||||
| { | |||||
| Name = op_name, | |||||
| NumInputs = op_inputs.Length, | |||||
| Inputs = op_inputs, | |||||
| NumOutputs = op_outputs.Length, | |||||
| Outputs = op_outputs, | |||||
| SkipInputIndices = unneeded_gradients, | |||||
| Attrs = attrs | |||||
| }, output_grads); | |||||
| return gradients;*/ | |||||
| }; | |||||
| return (_backward_function_wrapper, flat_outputs); | |||||
| } | |||||
| protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||||
| BuildFunctionsForOutputs(Tensors outputs, Tensors inference_args) | |||||
| { | |||||
| var trainable_outputs = new List<Tensor>(); | |||||
| var trainable_indices = new List<int>(); | |||||
| foreach(var (index, output) in enumerate(outputs)) | |||||
| { | |||||
| if (gradients_util.IsTrainable(output)) | |||||
| { | |||||
| trainable_outputs.Add(output); | |||||
| trainable_indices.Add(index); | |||||
| } | |||||
| } | |||||
| var gradients_wrt_outputs = new List<Tensor>(); | |||||
| var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"); | |||||
| foreach (var output in trainable_outputs) | |||||
| gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | |||||
| var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | |||||
| _func_graph.Inputs, | |||||
| grad_ys: gradients_wrt_outputs.ToArray(), | |||||
| src_graph: _func_graph); | |||||
| tf.Context.restore_mode(); | |||||
| var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"; | |||||
| var backward_function_attr = new Dictionary<string, string>(); | |||||
| backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||||
| backwards_graph.Inputs = gradients_wrt_outputs; | |||||
| backwards_graph.Outputs = gradients_wrt_inputs; | |||||
| var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||||
| var forward_function_attr = new Dictionary<string, string>(); | |||||
| forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||||
| var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||||
| _func_graph.Inputs, _func_graph.Outputs, forward_function_attr); | |||||
| return (forward_function, _func_graph, backward_function, null, 0); | |||||
| } | |||||
| public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -47,6 +47,9 @@ namespace Tensorflow | |||||
| string description, | string description, | ||||
| SafeStatusHandle status); | SafeStatusHandle status); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_FunctionName(IntPtr func); | public static extern IntPtr TF_FunctionName(IntPtr func); | ||||
| @@ -13,8 +13,6 @@ namespace Tensorflow.Gradients | |||||
| void RecordOperation(string op_type, | void RecordOperation(string op_type, | ||||
| Tensor[] input_tensors, | Tensor[] input_tensors, | ||||
| TapeTensor[] output_tensors, | TapeTensor[] output_tensors, | ||||
| long[] input_tensor_id, | |||||
| TF_DataType[] input_dtypes, | |||||
| Func<BackwardFunction> backward_function_getter); | Func<BackwardFunction> backward_function_getter); | ||||
| void VariableAccessed(ResourceVariable variable); | void VariableAccessed(ResourceVariable variable); | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.tensorflow; | using static Tensorflow.tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using System.Linq; | |||||
| namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
| { | { | ||||
| @@ -14,18 +15,19 @@ namespace Tensorflow.Gradients | |||||
| public void RecordOperation(string op_type, | public void RecordOperation(string op_type, | ||||
| Tensor[] input_tensors, | Tensor[] input_tensors, | ||||
| TapeTensor[] output_tensors, | TapeTensor[] output_tensors, | ||||
| long[] input_tensor_id, | |||||
| TF_DataType[] input_dtypes, | |||||
| Func<BackwardFunction> backward_function_getter) | Func<BackwardFunction> backward_function_getter) | ||||
| { | { | ||||
| if (!ShouldRecord(input_tensor_id, input_dtypes)) | |||||
| var input_ids = input_tensors.Select(x => x.Id).ToArray(); | |||||
| var input_dtypes = input_tensors.Select(x => x.dtype).ToArray(); | |||||
| if (!ShouldRecord(input_ids, input_dtypes)) | |||||
| { | { | ||||
| return; | return; | ||||
| } | } | ||||
| long op_id = next_op_id_++; | long op_id = next_op_id_++; | ||||
| var ids = new List<long>(input_tensor_id.Length); | |||||
| foreach (var i in input_tensor_id) | |||||
| var ids = new List<long>(input_ids.Length); | |||||
| foreach (var i in input_ids) | |||||
| { | { | ||||
| tensor_usage_[i]++; | tensor_usage_[i]++; | ||||
| ids.Add(i); | ids.Add(i); | ||||
| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Graphs; | |||||
| using Tensorflow.Operations.ControlFlows; | using Tensorflow.Operations.ControlFlows; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -39,7 +40,14 @@ namespace Tensorflow | |||||
| // If src_graph is a _FuncGraph (i.e. a function body), gather it and all | // If src_graph is a _FuncGraph (i.e. a function body), gather it and all | ||||
| // ancestor graphs. This is necessary for correctly handling captured values. | // ancestor graphs. This is necessary for correctly handling captured values. | ||||
| var func_graphs = new List<FuncGraph>(); | |||||
| var curr_graph = src_graph; | var curr_graph = src_graph; | ||||
| if (src_graph is FuncGraph func_graph) | |||||
| { | |||||
| func_graphs.append(func_graph); | |||||
| curr_graph = func_graph.OuterGraph; | |||||
| } | |||||
| if (stop_gradients == null) | if (stop_gradients == null) | ||||
| stop_gradients = new Tensor[0]; | stop_gradients = new Tensor[0]; | ||||
| @@ -84,7 +92,7 @@ namespace Tensorflow | |||||
| var to_ops = ys.Select(x => x.op).ToList(); | var to_ops = ys.Select(x => x.op).ToList(); | ||||
| var from_ops = xs.Select(x => x.op).ToList(); | var from_ops = xs.Select(x => x.op).ToList(); | ||||
| var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); | var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); | ||||
| (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs); | |||||
| (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs , xs); | |||||
| // Add the initial gradients for the ys. | // Add the initial gradients for the ys. | ||||
| foreach (var (y, grad_y) in zip(ys, grad_ys)) | foreach (var (y, grad_y) in zip(ys, grad_ys)) | ||||
| @@ -258,11 +266,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| var new_grad_ys = new List<Tensor>(); | var new_grad_ys = new List<Tensor>(); | ||||
| for (int i = 0; i < grad_ys.Length; i++) | |||||
| foreach(var (i, (y, grad_y)) in enumerate(zip(ys, grad_ys))) | |||||
| { | { | ||||
| var grad_y = grad_ys[i]; | |||||
| var y = ys[i]; | |||||
| _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops); | _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops); | ||||
| if (grad_y == null) | if (grad_y == null) | ||||
| @@ -272,8 +277,17 @@ namespace Tensorflow | |||||
| var shape = array_ops.shape(y); | var shape = array_ops.shape(y); | ||||
| var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); | var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); | ||||
| var fill = gen_array_ops.fill(shape, constant); | var fill = gen_array_ops.fill(shape, constant); | ||||
| new_grad_ys.Add(fill); | |||||
| new_grad_ys.append(fill); | |||||
| continue; | |||||
| } | } | ||||
| if (y.dtype.is_floating() || y.dtype.is_integer()) | |||||
| { | |||||
| } | |||||
| // Create a grad_y tensor in the name scope of the gradient. | |||||
| new_grad_ys.append(array_ops.identity(grad_y, name: $"grad_ys_{i}")); | |||||
| } | } | ||||
| return new_grad_ys.ToArray(); | return new_grad_ys.ToArray(); | ||||
| @@ -294,7 +308,11 @@ namespace Tensorflow | |||||
| /// <param name="colocate_gradients_with_ops"></param> | /// <param name="colocate_gradients_with_ops"></param> | ||||
| /// <param name="func_graphs"></param> | /// <param name="func_graphs"></param> | ||||
| /// <param name="xs"></param> | /// <param name="xs"></param> | ||||
| private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs) | |||||
| private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, | |||||
| List<Operation> from_ops, | |||||
| bool colocate_gradients_with_ops, | |||||
| List<FuncGraph> func_graphs, | |||||
| Tensor[] xs) | |||||
| { | { | ||||
| // Mark reachable ops from from_ops. | // Mark reachable ops from from_ops. | ||||
| var reached_ops = new List<Operation>(); | var reached_ops = new List<Operation>(); | ||||
| @@ -511,7 +529,7 @@ namespace Tensorflow | |||||
| /// <param name="from_ops"></param> | /// <param name="from_ops"></param> | ||||
| /// <param name="reached_ops"></param> | /// <param name="reached_ops"></param> | ||||
| /// <param name="func_graphs"></param> | /// <param name="func_graphs"></param> | ||||
| private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<object> func_graphs) | |||||
| private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<FuncGraph> func_graphs) | |||||
| { | { | ||||
| Queue<Operation> queue = new Queue<Operation>(from_ops); | Queue<Operation> queue = new Queue<Operation>(from_ops); | ||||
| while (queue.Count > 0) | while (queue.Count > 0) | ||||
| @@ -538,7 +556,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="t"></param> | /// <param name="t"></param> | ||||
| /// <param name="func_graphs"></param> | /// <param name="func_graphs"></param> | ||||
| private static Operation[] _Consumers(Tensor t, List<object> func_graphs) | |||||
| private static Operation[] _Consumers(Tensor t, List<FuncGraph> func_graphs) | |||||
| { | { | ||||
| return t.consumers(); | return t.consumers(); | ||||
| } | } | ||||
| @@ -647,7 +665,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private static bool IsTrainable(Tensor tensor) | |||||
| public static bool IsTrainable(Tensor tensor) | |||||
| { | { | ||||
| var dtype = tensor.dtype.as_base_dtype(); | var dtype = tensor.dtype.as_base_dtype(); | ||||
| return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64, | return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64, | ||||
| @@ -18,10 +18,11 @@ namespace Tensorflow.Graphs | |||||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
| var func_handle = graph.ToGraph(opers, | var func_handle = graph.ToGraph(opers, | ||||
| new Operation[] { input }, | |||||
| new Operation[] { output }, | |||||
| new[] { input }, | |||||
| new[] { output }, | |||||
| null); | null); | ||||
| } | } | ||||
| return (Tensor input) => | return (Tensor input) => | ||||
| { | { | ||||
| @@ -48,11 +49,11 @@ namespace Tensorflow.Graphs | |||||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
| var func_handle = graph.ToGraph(opers, | var func_handle = graph.ToGraph(opers, | ||||
| new Operation[] { input1, input2 }, | |||||
| new Operation[] { output }, | |||||
| new[] { input1, input2 }, | |||||
| new[] { output }, | |||||
| null); | null); | ||||
| } | } | ||||
| return (Tensor a, Tensor b) => | return (Tensor a, Tensor b) => | ||||
| { | { | ||||
| var result = tf.Runner.TFE_Execute(tf.Context, | var result = tf.Runner.TFE_Execute(tf.Context, | ||||
| @@ -3,6 +3,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Functions; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Graphs | namespace Tensorflow.Graphs | ||||
| @@ -10,10 +11,10 @@ namespace Tensorflow.Graphs | |||||
| [AllowChangingInputArguments] | [AllowChangingInputArguments] | ||||
| public sealed class AutoGraphAttribute : OnMethodBoundaryAspect | public sealed class AutoGraphAttribute : OnMethodBoundaryAspect | ||||
| { | { | ||||
| FuncGraph graph; | |||||
| ConcreteFunction function; | |||||
| Tensors originalInputs; | Tensors originalInputs; | ||||
| string func_name; | string func_name; | ||||
| static Dictionary<string, Func<Tensors, Tensors>> functions = new Dictionary<string, Func<Tensors, Tensors>>(); | |||||
| static Dictionary<string, ConcreteFunction> functions = new Dictionary<string, ConcreteFunction>(); | |||||
| public override void OnEntry(MethodExecutionArgs args) | public override void OnEntry(MethodExecutionArgs args) | ||||
| { | { | ||||
| @@ -21,22 +22,24 @@ namespace Tensorflow.Graphs | |||||
| if (functions.ContainsKey(func_name)) | if (functions.ContainsKey(func_name)) | ||||
| { | { | ||||
| function = functions[func_name]; | |||||
| if (args.Arguments[0] is Tensors tensor_inputs) | if (args.Arguments[0] is Tensors tensor_inputs) | ||||
| args.ReturnValue = functions[func_name](tensor_inputs.ToArray()); | |||||
| args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs)); | |||||
| else | else | ||||
| args.ReturnValue = functions[func_name](args.Arguments.Select(x => x as Tensor).ToArray()); | |||||
| args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray())); | |||||
| args.FlowBehavior = FlowBehavior.Return; | args.FlowBehavior = FlowBehavior.Return; | ||||
| return; | return; | ||||
| } | } | ||||
| // make function as an Operation by autograph | // make function as an Operation by autograph | ||||
| graph = new FuncGraph(func_name); | |||||
| // need to restore mode when exits | |||||
| function = new ConcreteFunction(func_name); | |||||
| // convert to Tensors | // convert to Tensors | ||||
| if (args.Arguments[0] is Tensors inputs) | if (args.Arguments[0] is Tensors inputs) | ||||
| { | { | ||||
| originalInputs = inputs; | originalInputs = inputs; | ||||
| var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape)).ToArray(); | |||||
| var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape, name: "inputs")).ToArray(); | |||||
| args.Arguments[0] = new Tensors(new_inputs); | args.Arguments[0] = new Tensors(new_inputs); | ||||
| } | } | ||||
| else | else | ||||
| @@ -48,7 +51,7 @@ namespace Tensorflow.Graphs | |||||
| if (args.Arguments[i] is EagerTensor tensor) | if (args.Arguments[i] is EagerTensor tensor) | ||||
| { | { | ||||
| originalInputs[i] = tensor; | originalInputs[i] = tensor; | ||||
| args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape); | |||||
| args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs"); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -56,58 +59,30 @@ namespace Tensorflow.Graphs | |||||
| public override void OnExit(MethodExecutionArgs args) | public override void OnExit(MethodExecutionArgs args) | ||||
| { | { | ||||
| var returnValue = mark_as_return(args.ReturnValue as Tensors); | |||||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
| if (args.ReturnValue is Tensors outputs) | if (args.ReturnValue is Tensors outputs) | ||||
| { | { | ||||
| if (args.Arguments[0] is Tensors inputs) | if (args.Arguments[0] is Tensors inputs) | ||||
| { | |||||
| graph.ToGraph(opers, | |||||
| inputs.Select(x => x.op).ToArray(), | |||||
| outputs.Select(x => x.op).ToArray(), | |||||
| null); | |||||
| } | |||||
| function.ToGraph(inputs, outputs); | |||||
| else | else | ||||
| { | |||||
| graph.ToGraph(opers, | |||||
| args.Arguments.Select(x => (x as Tensor).op).ToArray(), | |||||
| outputs.Select(x => x.op).ToArray(), | |||||
| null); | |||||
| } | |||||
| function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs); | |||||
| } | } | ||||
| else | else | ||||
| { | |||||
| graph.ToGraph(opers, | |||||
| args.Arguments.Select(x => (x as Tensor).op).ToArray(), | |||||
| new Operation[] { (args.ReturnValue as Tensor).op }, | |||||
| null); | |||||
| } | |||||
| graph.Dispose(); | |||||
| function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); | |||||
| Func<Tensors, Tensors> function = (x) => | |||||
| { | |||||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||||
| tf.Context.DeviceName, | |||||
| func_name, | |||||
| x, | |||||
| null, | |||||
| 1); | |||||
| return result[0]; | |||||
| }; | |||||
| // cache function. | // cache function. | ||||
| function.ReturnType = args.ReturnValue.GetType(); | |||||
| functions[func_name] = function; | functions[func_name] = function; | ||||
| // run function | // run function | ||||
| args.ReturnValue = function(originalInputs); | |||||
| args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs)); | |||||
| } | } | ||||
| Tensor mark_as_return(Tensor tensor) | |||||
| object ConvertReturnValue(Tensors tensors) | |||||
| { | { | ||||
| return array_ops.identity(tensor); | |||||
| if (function.ReturnType == typeof(Tensor)) | |||||
| return (Tensor)tensors; | |||||
| else | |||||
| return tensors; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,13 +10,18 @@ namespace Tensorflow.Graphs | |||||
| /// </summary> | /// </summary> | ||||
| public class FuncGraph : Graph | public class FuncGraph : Graph | ||||
| { | { | ||||
| List<Operation> inputs; | |||||
| List<Operation> outputs; | |||||
| Graph outer_graph; | Graph outer_graph; | ||||
| public Graph OuterGraph => outer_graph; | |||||
| string func_name; | string func_name; | ||||
| // _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||||
| IntPtr func_handle; | IntPtr func_handle; | ||||
| public string FuncName => func_name; | public string FuncName => func_name; | ||||
| public Tensors Inputs { get; set; } | |||||
| public Tensors Outputs { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Construct a new FuncGraph. | /// Construct a new FuncGraph. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -29,8 +34,17 @@ namespace Tensorflow.Graphs | |||||
| as_default(); | as_default(); | ||||
| } | } | ||||
| public FuncGraph(IntPtr handle, string name) | |||||
| { | |||||
| outer_graph = ops.get_default_graph(); | |||||
| func_name = name; | |||||
| tf.Context.graph_mode(); | |||||
| as_default(); | |||||
| } | |||||
| public IntPtr ToGraph(Operation[] opers, | public IntPtr ToGraph(Operation[] opers, | ||||
| Operation[] inputs, Operation[] outputs, | |||||
| Tensor[] inputs, Tensor[] outputs, | |||||
| string[] output_names) | string[] output_names) | ||||
| { | { | ||||
| using var status = new Status(); | using var status = new Status(); | ||||
| @@ -40,9 +54,9 @@ namespace Tensorflow.Graphs | |||||
| opers.Length, | opers.Length, | ||||
| opers.Select(x => (IntPtr)x).ToArray(), | opers.Select(x => (IntPtr)x).ToArray(), | ||||
| inputs.Length, | inputs.Length, | ||||
| inputs.Select(x => new TF_Output(x, 0)).ToArray(), | |||||
| inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||||
| outputs.Length, | outputs.Length, | ||||
| outputs.Select(x => new TF_Output(x, 0)).ToArray(), | |||||
| outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||||
| output_names == null || output_names.Length == 0 ? null : output_names, | output_names == null || output_names.Length == 0 ? null : output_names, | ||||
| IntPtr.Zero, | IntPtr.Zero, | ||||
| null, | null, | ||||
| @@ -57,13 +71,18 @@ namespace Tensorflow.Graphs | |||||
| func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); | func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); | ||||
| Inputs = inputs; | |||||
| // mark_as_return | |||||
| Outputs = outputs.Select(x => array_ops.identity(x)).ToArray(); | |||||
| tf.Context.restore_mode(); | |||||
| return func_handle; | return func_handle; | ||||
| } | } | ||||
| protected override void DisposeManagedResources() | protected override void DisposeManagedResources() | ||||
| { | { | ||||
| base.DisposeManagedResources(); | base.DisposeManagedResources(); | ||||
| tf.Context.restore_mode(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -69,7 +69,7 @@ namespace Tensorflow | |||||
| throw new ValueError($"Could not find operation \"{operName}\" inside graph \"{_graph_key}\"."); | throw new ValueError($"Could not find operation \"{operName}\" inside graph \"{_graph_key}\"."); | ||||
| var defaultKey = tf.get_default_graph().graph_key; | var defaultKey = tf.get_default_graph().graph_key; | ||||
| if (graph_key != defaultKey) | |||||
| if (tf.get_default_graph().GetType().Name == "Graph" && graph_key != defaultKey) | |||||
| { | { | ||||
| //Console.WriteLine($"Current graph is not default graph."); | //Console.WriteLine($"Current graph is not default graph."); | ||||
| throw new RuntimeError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}"); | throw new RuntimeError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}"); | ||||
| @@ -218,6 +218,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| case nameof(Int32): | case nameof(Int32): | ||||
| return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); | return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); | ||||
| case nameof(Int64): | |||||
| return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); | |||||
| default: | default: | ||||
| return null; | return null; | ||||
| } | } | ||||
| @@ -235,6 +235,14 @@ namespace Tensorflow | |||||
| } | } | ||||
| var _op = tf.OpDefLib._apply_op_helper("Identity", name, new { input }); | var _op = tf.OpDefLib._apply_op_helper("Identity", name, new { input }); | ||||
| if (tf.Runner.MustRecordGradient()) | |||||
| { | |||||
| tf.Runner.RecordGradient("Identity", _op.inputs, new object[] | |||||
| { | |||||
| "T", _op.get_attr<TF_DataType>("T") | |||||
| }, _op.outputs); | |||||
| } | |||||
| return _op.output; | return _op.output; | ||||
| } | } | ||||
| @@ -632,8 +640,8 @@ namespace Tensorflow | |||||
| public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, | public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, | ||||
| int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, | int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, | ||||
| int shrink_axis_mask = 0, string name = null) | int shrink_axis_mask = 0, string name = null) | ||||
| => tf.Context.RunInAutoMode(() | |||||
| => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new | |||||
| => tf.Context.RunInAutoMode2( | |||||
| () => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new | |||||
| { | { | ||||
| shape, | shape, | ||||
| begin, | begin, | ||||
| @@ -645,8 +653,8 @@ namespace Tensorflow | |||||
| ellipsis_mask, | ellipsis_mask, | ||||
| new_axis_mask, | new_axis_mask, | ||||
| shrink_axis_mask | shrink_axis_mask | ||||
| }).output, () | |||||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| }).output, | |||||
| () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "StridedSliceGrad", name, | "StridedSliceGrad", name, | ||||
| null, | null, | ||||
| shape, begin, end, strides, dy, | shape, begin, end, strides, dy, | ||||
| @@ -654,8 +662,22 @@ namespace Tensorflow | |||||
| "end_mask", end_mask, | "end_mask", end_mask, | ||||
| "ellipsis_mask", ellipsis_mask, | "ellipsis_mask", ellipsis_mask, | ||||
| "new_axis_mask", new_axis_mask, | "new_axis_mask", new_axis_mask, | ||||
| "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||||
| shape, begin, end, strides, dy); | |||||
| "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||||
| (op) => | |||||
| { | |||||
| var attrs = new object[] | |||||
| { | |||||
| "T", op.get_attr<TF_DataType>("T"), | |||||
| "Index", op.get_attr<TF_DataType>("Index"), | |||||
| "begin_mask", op.get_attr<long>("begin_mask"), | |||||
| "end_mask", op.get_attr<long>("end_mask"), | |||||
| "ellipsis_mask", op.get_attr<long>("ellipsis_mask"), | |||||
| "new_axis_mask", op.get_attr<long>("new_axis_mask"), | |||||
| "shrink_axis_mask", op.get_attr<long>("shrink_axis_mask") | |||||
| }; | |||||
| tf.Runner.RecordGradient("StridedSliceGrad", op.inputs, attrs, op.outputs); | |||||
| }, | |||||
| new Tensors(shape, begin, end, strides, dy)); | |||||
| /// <summary> | /// <summary> | ||||
| /// Removes dimensions of size 1 from the shape of a tensor. | /// Removes dimensions of size 1 from the shape of a tensor. | ||||
| @@ -23,6 +23,8 @@ using Tensorflow.Gradients; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); | |||||
| public partial class tensorflow : ITensorFlowObject | public partial class tensorflow : ITensorFlowObject | ||||
| { | { | ||||
| public TF_DataType byte8 = TF_DataType.TF_UINT8; | public TF_DataType byte8 = TF_DataType.TF_UINT8; | ||||
| @@ -37,8 +39,6 @@ namespace Tensorflow | |||||
| public TF_DataType chars = TF_DataType.TF_STRING; | public TF_DataType chars = TF_DataType.TF_STRING; | ||||
| public TF_DataType @string = TF_DataType.TF_STRING; | public TF_DataType @string = TF_DataType.TF_STRING; | ||||
| public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); | |||||
| public Status Status; | public Status Status; | ||||
| public OpDefLibrary OpDefLib; | public OpDefLibrary OpDefLib; | ||||
| public Context Context; | public Context Context; | ||||
| @@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Engine | |||||
| // Record the gradient because custom-made ops don't go through the | // Record the gradient because custom-made ops don't go through the | ||||
| // code-gen'd eager call path | // code-gen'd eager call path | ||||
| var op_type = op.node_def.Name; | |||||
| var op_type = op.node_def.Op; | |||||
| tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | ||||
| @@ -1,9 +1,9 @@ | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using static Tensorflow.KerasApi; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System; | using System; | ||||
| using System.Linq; | |||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| { | { | ||||
| @@ -26,9 +26,20 @@ namespace Tensorflow.Keras.Layers | |||||
| var result = array_ops.reshape(inputs, shape.ToArray()); | var result = array_ops.reshape(inputs, shape.ToArray()); | ||||
| if (!tf.Context.executing_eagerly()) | if (!tf.Context.executing_eagerly()) | ||||
| // result = result.set_shape(compute_output_shape(inputs.shape)); | |||||
| throw new NotImplementedException(""); | |||||
| result.set_shape(compute_output_shape(inputs.shape)); | |||||
| return result; | return result; | ||||
| } | } | ||||
| TensorShape compute_output_shape(TensorShape input_shape) | |||||
| { | |||||
| if (input_shape.dims[0] == -1) | |||||
| { | |||||
| input_shape = input_shape.dims[0]; | |||||
| var output_shape = input_shape.concatenate(args.TargetShape.dims); | |||||
| return output_shape; | |||||
| } | |||||
| else | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Graphs; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.ManagedAPI | namespace TensorFlowNET.UnitTest.ManagedAPI | ||||
| @@ -36,7 +37,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
| /// <param name="a"></param> | /// <param name="a"></param> | ||||
| /// <param name="b"></param> | /// <param name="b"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| // [AutoGraph] | |||||
| [AutoGraph] | |||||
| Tensor Mul(Tensor a, Tensor b) | Tensor Mul(Tensor a, Tensor b) | ||||
| { | { | ||||
| return a * b; | return a * b; | ||||