From 651d5d210065a0a274d7240bb92e32f88b208845 Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Thu, 31 Jan 2019 22:02:56 -0600
Subject: [PATCH] a bunch of changes for Gradients.

---
 src/TensorFlowNET.Core/APIs/tf.gradients.cs   |  45 +++
 src/TensorFlowNET.Core/Eager/Context.cs       |   5 +
 .../Gradients/gradients_impl.py.cs            | 286 ++++++++++++++++--
 .../Operations/array_ops.py.cs                |  42 +++
 .../Operations/control_flow_ops.py.cs         |  29 ++
 src/TensorFlowNET.Core/Python.cs              |   2 +-
 src/TensorFlowNET.Core/Tensors/TensorShape.cs |  10 +
 src/TensorFlowNET.Core/Tensors/constant_op.cs |   5 +
 src/TensorFlowNET.Core/Tensors/dtypes.cs      |   5 +
 src/TensorFlowNET.Core/Train/Optimizer.cs     |   6 +-
 src/TensorFlowNET.Core/ops.py.cs              |   6 +
 src/TensorFlowNET.Core/tf.cs                  |   2 +-
 test/TensorFlowNET.UnitTest/GradientTest.cs   |  20 ++
 13 files changed, 429 insertions(+), 34 deletions(-)
 create mode 100644 src/TensorFlowNET.Core/APIs/tf.gradients.cs
 create mode 100644 test/TensorFlowNET.UnitTest/GradientTest.cs

diff --git a/src/TensorFlowNET.Core/APIs/tf.gradients.cs b/src/TensorFlowNET.Core/APIs/tf.gradients.cs
new file mode 100644
index 00000000..115b7fef
--- /dev/null
+++ b/src/TensorFlowNET.Core/APIs/tf.gradients.cs
@@ -0,0 +1,45 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    public static partial class tf
+    {
+        public static object gradients(Tensor[] ys,
+            Tensor[] xs,
+            Tensor[] grad_ys = null,
+            string name = "gradients",
+            bool colocate_gradients_with_ops = false,
+            bool gate_gradients = false,
+            int? aggregation_method = null,
+            Tensor[] stop_gradients = null)
+        {
+            return gradients_impl._GradientsHelper(ys,
+                xs,
+                grad_ys,
+                name,
+                colocate_gradients_with_ops,
+                gate_gradients,
+                stop_gradients: stop_gradients);
+        }
+
+        public static object gradients(Tensor ys,
+            Tensor[] xs,
+            Tensor[] grad_ys = null,
+            string name = "gradients",
+            bool colocate_gradients_with_ops = false,
+            bool gate_gradients = false,
+            int? aggregation_method = null,
+            Tensor[] stop_gradients = null)
+        {
+            return gradients_impl._GradientsHelper(new Tensor[] { ys },
+                xs,
+                grad_ys,
+                name,
+                colocate_gradients_with_ops,
+                gate_gradients,
+                stop_gradients: stop_gradients);
+        }
+    }
+}
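For orientation before the implementation below: `ys` are the targets to differentiate, `xs` the tensors to differentiate against, and `stop_gradients` marks tensors that are treated as constants during backprop; the two overloads differ only in whether `ys` is a single Tensor or an array. A minimal graph-mode usage sketch (the arithmetic is illustrative and mirrors the unit test at the end of this patch; the `object` return type is what this revision exposes):

    var a = tf.constant(0.0);
    var b = 2.0 * a;
    var y = a + b;
    // dy/da and dy/db, with both a and b held constant.
    var g = tf.gradients(y, new Tensor[] { a, b },
        stop_gradients: new Tensor[] { a, b });
    // Because every input is stopped, each gradient reduces to the default
    // grad_ys: a tensor of ones shaped like y (see _DefaultGradYs below).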
diff --git a/src/TensorFlowNET.Core/Eager/Context.cs b/src/TensorFlowNET.Core/Eager/Context.cs
index f32790d8..5a3b35db 100644
--- a/src/TensorFlowNET.Core/Eager/Context.cs
+++ b/src/TensorFlowNET.Core/Eager/Context.cs
@@ -24,6 +24,11 @@ namespace Tensorflow.Eager
             c_api.TFE_DeleteContext(_handle);
         }
 
+        public bool executing_eagerly()
+        {
+            return false;
+        }
+
         public static implicit operator IntPtr(Context ctx)
         {
             return ctx._handle;
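The `executing_eagerly()` stub hard-codes graph mode, which lets the rest of this patch branch on the eventual eager/graph split without implementing eager execution yet. The intended guard pattern, as used later in `array_ops.shape_internal` and `constant_op.constant`:

    if (!tf.context.executing_eagerly())
    {
        // Graph mode: build nodes into the default graph.
    }
    else
    {
        // Eager mode: not implemented in this revision.
    }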
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
index e884bdac..04e13823 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
@@ -1,4 +1,5 @@
-using System;
+using NumSharp.Core;
+using System;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
@@ -8,9 +9,9 @@ namespace Tensorflow
 {
     public class gradients_impl
     {
-        public static void gradients(object ys,
-            object xs,
-            List<Tensor> grad_ys = null,
+        public static void gradients(Tensor[] ys,
+            Tensor[] xs,
+            Tensor[] grad_ys = null,
             string name = "gradients",
             bool colocate_gradients_with_ops = false,
             bool gate_gradients = false,
@@ -19,13 +20,14 @@ namespace Tensorflow
             _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients);
         }
 
-        public static void _GradientsHelper(object ys,
-            object xs,
-            object grad_ys = null,
+        public static Tensor[] _GradientsHelper(Tensor[] ys,
+            Tensor[] xs,
+            Tensor[] grad_ys = null,
             string name = "gradients",
             bool colocate_gradients_with_ops = false,
             bool gate_gradients = false,
-            object stop_gradients = null,
+            int aggregation_method = 0,
+            Tensor[] stop_gradients = null,
             Graph src_graph = null)
         {
             if (src_graph == null)
@@ -35,20 +37,14 @@ namespace Tensorflow
             // ancestor graphs. This is necessary for correctly handling captured values.
             var curr_graph = src_graph;
 
-            var ys1 = _AsList(ys);
-            var xs1 = _AsList(xs);
-            List<Tensor> grad_ys1 = null;
-            List<Tensor> stop_gradients1 = stop_gradients == null ? new List<Tensor>() : _AsList(stop_gradients);
             if (grad_ys == null)
-                grad_ys1 = ys1.Select(x => new Tensor(IntPtr.Zero)).ToList();
-            else
-                grad_ys = _AsList(grad_ys);
+                grad_ys = new Tensor[ys.Length];
+
+            // Callers may omit stop_gradients entirely; normalize to an
+            // empty array so the AddRange/Select calls below are safe.
+            if (stop_gradients == null)
+                stop_gradients = new Tensor[0];
 
             var all = new List<Tensor>();
-            all.AddRange(ys1);
-            all.AddRange(xs1);
-            all.AddRange(stop_gradients1);
-            all.AddRange(grad_ys1);
+            all.AddRange(ys);
+            all.AddRange(xs);
+            all.AddRange(stop_gradients);
+            all.AddRange(grad_ys);
 
             Python.with(new ops.name_scope(name, "gradients", values: all), scope =>
             {
@@ -56,24 +52,209 @@ namespace Tensorflow
                 // Get a uid for this call to gradients that can be used to help
                 // cluster ops for compilation.
                 var gradient_uid = ops.get_default_graph().unique_name("uid");
+                grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid);
+
+                /**
+                 * The approach we take here is as follows: Create a list of all ops in the
+                 * subgraph between the ys and xs. Visit these ops in reverse order of ids
+                 * to ensure that when we visit an op the gradients w.r.t its outputs have
+                 * been collected. Then aggregate these gradients if needed, call the op's
+                 * gradient function, and add the generated gradients to the gradients for
+                 * its input.
+                 **/
 
                 // Initialize the pending count for ops in the connected subgraph from ys
                 // to the xs.
-                var to_ops = ys1.Select(x => x.op).ToList();
-                var from_ops = xs1.Select(x => x.op).ToList();
-                var stop_gradient_ops = stop_gradients1.Select(x => x.op).ToList();
-                _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs1);
+                var to_ops = ys.Select(x => x.op).ToList();
+                var from_ops = xs.Select(x => x.op).ToList();
+                var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
+                (var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
+
+                // Iterate over the collected ops.
+                /**
+                 * grads: op => list of gradients received on each output endpoint of the
+                 * op. The gradients for each endpoint are initially collected as a list.
+                 * When it is time to call the op's gradient function, for each endpoint we
+                 * aggregate the list of received gradients into a Add() Operation if there
+                 * is more than one.
+                 **/
+                var grads = new Dictionary<string, Tensor[][]>();
+                for (int i = 0; i < ys.Count(); i++)
+                {
+                    (var y, var grad_y) = Python.zip(ys, grad_ys, i);
+                    _SetGrad(grads, y, grad_y);
+                }
+
+                // Initialize queue with to_ops.
+                var queue = new Queue<Operation>();
+                // Add the ops in 'to_ops' into the queue.
+                var to_ops_set = new List<Operation>();
+                foreach (var op in to_ops)
+                {
+                    // 'ready' handles the case where one output gradient relies on
+                    // another output's gradient.
+                    bool ready = !pending_count.ContainsKey(op.Name) || pending_count[op.Name] == 0;
+                    if (ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op))
+                    {
+                        to_ops_set.Add(op);
+                        queue.Enqueue(op);
+                    }
+                }
+
+                var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
+                while (queue.Count > 0)
+                {
+                    // generate gradient subgraph for op.
+                    var op = queue.Dequeue();
+                    _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
+                    //if (loop_state != null)
+                    //    loop_state.EnterGradWhileContext(op, before: true);
+                    var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
+
+                    var is_partitioned_call = _IsPartitionedCall(op);
+                    var is_func_call = false;
+                    var has_out_grads = true;
+                    if (has_out_grads && !stop_ops.Contains(op))
+                    {
+                        if (is_func_call)
+                        {
+
+                        }
+                        else
+                        {
+                            // A grad_fn must be defined, either as a function or as None
+                            // for ops that do not have gradients.
+
+                        }
+                    }
+                }
             });
+
+            return null;
+        }
+
+        private static bool _IsPartitionedCall(Operation op)
+        {
+            return op.OpType == "PartitionedCall" || op.OpType == "StatefulPartitionedCall";
+        }
+
+        private static Tensor[] _AggregatedGrads(Dictionary<string, Tensor[][]> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
+        {
+            var out_grads = _GetGrads(grads, op);
+            for (int i = 0; i < out_grads.Length; i++)
+            {
+                var out_grad = out_grads[i];
+                if (loop_state != null)
+                {
+
+                }
+
+                // Grads have to be Tensors or IndexedSlices
+
+                // Aggregate multiple gradients, and convert [] to None.
+                // A single received gradient needs no aggregation; guarding on
+                // an exact length of one also keeps an empty list from being
+                // dereferenced.
+                if (out_grad != null && out_grad.Length == 1)
+                {
+                    return new Tensor[] { out_grad[0] };
+                }
+            }
+
+            return null;
+        }
 
         /// <summary>
-        /// 
+        /// The set of ops that terminate the gradient computation.
         /// </summary>
-        /// <param name=""></param>
-        /// <param name=""></param>
+        /// <param name="from_ops">list of Operations.</param>
+        /// <param name="stop_gradient_ops">list of Operations never to backprop through.</param>
+        /// <param name="pending_count">mapping from operation to number of backprop inputs.</param>
+        /// <param name="xs">list of Tensors.</param>
+        /// <returns>The set of operations.</returns>
+        private static Operation[] _StopOps(List<Operation> from_ops, List<Operation> stop_gradient_ops, Dictionary<string, int> pending_count, Tensor[] xs)
+        {
+            var stop_ops = new List<Operation>();
+
+            foreach (var op in from_ops)
+            {
+                bool is_stop_op = true;
+                foreach (Tensor inp in _NonEagerInputs(op, xs))
+                {
+                    // An op is a stop op only if none of the producers of its
+                    // inputs still expects a gradient, so the producer of each
+                    // input is what gets checked here.
+                    if (pending_count.ContainsKey(inp.op.Name) && pending_count[inp.op.Name] > 0)
+                    {
+                        is_stop_op = false;
+                        break;
+                    }
+                }
+                if (is_stop_op)
+                    stop_ops.Add(op);
+            }
+            stop_ops.AddRange(stop_gradient_ops.Where(x => !stop_ops.Contains(x)));
+            return stop_ops.ToArray();
+        }
+
+        private static Tensor[][] _GetGrads(Dictionary<string, Tensor[][]> grads, Operation op)
+        {
+            if (grads.ContainsKey(op.Name))
+                return grads[op.Name];
+            else
+                return op.outputs.Select(x => new Tensor[0]).ToArray();
+        }
+
+        /// <summary>
+        /// Sets gradient "grad" in "grads" for tensor "t".
+        /// </summary>
+        /// <param name="grads"></param>
+        /// <param name="t"></param>
+        /// <param name="grad"></param>
+        private static void _SetGrad(Dictionary<string, Tensor[][]> grads, Tensor t, Tensor grad)
+        {
+            var op = t.op;
+            Tensor[][] op_grads = null;
+            if (!grads.ContainsKey(op.Name))
+            {
+                op_grads = op.outputs.Select(x => new Tensor[1]).ToArray();
+                grads[op.Name] = op_grads;
+            }
+            else
+            {
+                // Reuse the slots already registered for this op.
+                op_grads = grads[op.Name];
+            }
+            var t_grads = op_grads[t.value_index];
+            t_grads[0] = grad;
+        }
 
+        /// <summary>
+        /// Fill in default values for grad_ys.
+        /// </summary>
+        /// <param name="grad_ys">List of gradients, can contain None.</param>
+        /// <param name="ys">List of tensors.</param>
+        /// <param name="colocate_gradients_with_ops"></param>
         /// <returns></returns>
-        private void _DefaultGradYs(List<Tensor> grad_ys, List<Tensor> ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__")
+        private static Tensor[] _DefaultGradYs(Tensor[] grad_ys, Tensor[] ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__")
+        {
+            var new_grad_ys = new List<Tensor>();
+
+            for (int i = 0; i < grad_ys.Length; i++)
+            {
+                var grad_y = grad_ys[i];
+                var y = ys[i];
+
+                _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops);
+
+                if (grad_y == null)
+                {
+                    if (y.dtype.is_complex())
+                        throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})");
+                    var shape = array_ops.shape(y);
+                    var constant = constant_op.constant(1.0, name: $"grad_ys_{i}");
+                    var fill = gen_array_ops.fill(shape, constant);
+                    new_grad_ys.Add(fill);
+                }
+                else
+                {
+                    // A supplied gradient is kept as-is so positions stay
+                    // aligned with ys.
+                    new_grad_ys.Add(grad_y);
+                }
+            }
+
+            return new_grad_ys.ToArray();
+        }
+
+        private static void _maybe_colocate_with(Operation op, string gradient_uid, bool colocate_gradients_with_ops)
         {
         }
 
@@ -88,12 +269,59 @@ namespace Tensorflow
         /// <summary>
        /// </summary>
         /// <param name="xs"></param>
-        private static void _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, List<Tensor> xs)
+        private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
         {
-            List<Operation> reached_ops = new List<Operation>();
+            // Mark reachable ops from from_ops.
+            var reached_ops = new List<Operation>();
             _MarkReachedOps(from_ops, reached_ops, func_graphs);
+            // X in reached_ops iff X is reachable from from_ops by a path of zero or more
+            // backpropagatable tensors.
+
+            var reachable_to_ops = to_ops.Where(x => reached_ops.Contains(x)).Select(x => x).ToArray();
+
+            var between_ops = new List<Operation>();
+            var between_op_list = new List<Operation>();
+
+            Queue<Operation> queue = new Queue<Operation>(to_ops);
+            while (queue.Count > 0)
+            {
+                var op = queue.Dequeue();
+                if (reached_ops.Contains(op))
+                {
+                    between_ops.Add(op);
+                    between_op_list.Insert(between_op_list.Count, op);
+                    // Clear the boolean so we won't add the inputs again.
+                    reached_ops.Remove(op);
+                    foreach (var inp in _NonEagerInputs(op, xs))
+                        queue.Enqueue((inp as Tensor).op);
+                }
+            }
+            // X in between_ops iff X is on a path of zero or more backpropagatable tensors
+            // between from_ops and to_ops
+
+            // 'loop_state' is None if there are no while loops.
+            var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops);
+
+            var pending_count = new Dictionary<string, int>();
+            foreach (var op in between_op_list)
+            {
+                foreach (Tensor x in _NonEagerInputs(op, xs))
+                {
+                    if (between_ops.Contains(x.op))
+                    {
+                        if (pending_count.ContainsKey(x.op.Name))
+                            pending_count[x.op.Name] += 1;
+                        else
+                            pending_count[x.op.Name] = 1;
+                    }
+                }
+            }
+
+            return (reachable_to_ops.ToArray(), pending_count, loop_state);
         }
 
+        private static InputList _NonEagerInputs(Operation op, Tensor[] xs)
+        {
+            return op.inputs;
+        }
+
         /// <summary>
         /// Mark all ops reached from "from_ops"
         /// </summary>
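The traversal above is easier to see without the TensorFlow types. A self-contained sketch of the same pending-count/worklist idea (`Node` and `Backprop` are illustrative stand-ins, not types from this patch): first count, for every node, how many of its consumers lie in the subgraph; then repeatedly dequeue a node whose count has dropped to zero, which guarantees all gradients flowing into it have already been produced:

    using System;
    using System.Collections.Generic;

    class Node
    {
        public string Name;
        public List<Node> Inputs = new List<Node>();
    }

    static class Backprop
    {
        public static void Visit(Node y)
        {
            // pending[n] = consumers of n whose gradient is not yet computed.
            var pending = new Dictionary<string, int>();
            var stack = new Stack<Node>(new[] { y });
            var seen = new HashSet<string>();
            while (stack.Count > 0)
            {
                var op = stack.Pop();
                if (!seen.Add(op.Name)) continue;
                foreach (var inp in op.Inputs)
                {
                    pending[inp.Name] = pending.TryGetValue(inp.Name, out var c) ? c + 1 : 1;
                    stack.Push(inp);
                }
            }

            // Reverse order of the forward graph: a node is ready once
            // every consumer has contributed its gradient.
            var queue = new Queue<Node>(new[] { y });
            while (queue.Count > 0)
            {
                var op = queue.Dequeue();
                Console.WriteLine($"grad_fn({op.Name})");
                foreach (var inp in op.Inputs)
                    if (--pending[inp.Name] == 0)
                        queue.Enqueue(inp);
            }
        }
    }

_GradientsHelper plays the same game keyed on `Operation.Name`, with `_PendingCount` computing the counts and `stop_ops` cutting the walk short.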
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 14594755..6da5ad2a 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -52,5 +52,47 @@ namespace Tensorflow
                 return gen_array_ops.fill(tShape, c, name);
             }
         }
+
+        /// <summary>
+        /// Returns the shape of a tensor.
+        /// </summary>
+        /// <param name="input">A `Tensor` or `SparseTensor`.</param>
+        /// <param name="name">A name for the operation (optional).</param>
+        /// <param name="out_type">
+        /// (Optional) The specified output type of the operation
+        /// (`int32` or `int64`). Defaults to `tf.int32`.
+        /// </param>
+        /// <returns>A `Tensor` of type `out_type`.</returns>
+        public static Tensor shape(Tensor input, string name = "", TF_DataType out_type = TF_DataType.TF_INT32)
+        {
+            return shape_internal(input, name, optimize: true, out_type: out_type);
+        }
+
+        private static Tensor shape_internal(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
+        {
+            Tensor result = null;
+
+            Python.with(new ops.name_scope(name, "Shape", new Tensor[] { input }), scope =>
+            {
+                name = scope;
+
+                if (!tf.context.executing_eagerly())
+                {
+                    var input_tensor = ops.convert_to_tensor(input);
+                    var input_shape = tensor_util.to_shape(input_tensor.shape);
+                    if (optimize && input_shape.is_fully_defined())
+                    {
+                        var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype());
+                        result = constant_op.constant(nd, name);
+                    }
+                }
+                else
+                {
+                    // result = gen_array_ops.shape();
+                }
+            });
+
+            return result;
+        }
     }
 }
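Worth noting: when the static shape is fully defined, `shape()` constant-folds instead of emitting a `Shape` node, which is exactly what `_DefaultGradYs` relies on when building `fill(shape(y), 1)`. A usage sketch (the `np.zeros` overload is an assumption about the NumSharp surface, not part of this diff):

    var t = tf.constant(np.zeros(2, 3));  // static shape (2, 3), fully defined
    var s = array_ops.shape(t);           // folded into a constant tensor [2, 3]
    // When some dimension is unknown, a real Shape op would be required;
    // that path is still a stub here, so result stays null in that case.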
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index 6b784239..913f56c4 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -52,5 +52,34 @@ namespace Tensorflow
 
             return result;
         }
+
+        /// <summary>
+        /// Create the state for all the while loops involved in one gradients().
+        /// </summary>
+        /// <param name="between_op_list"></param>
+        /// <param name="between_ops"></param>
+        /// <param name="colocate_gradients_with_ops"></param>
+        public static object MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops)
+        {
+            object loop_state = null;
+
+            foreach (var op in between_op_list)
+            {
+                if (IsLoopExit(op))
+                {
+                    if (loop_state == null)
+                    {
+                        // loop_state = ControlFlowState();
+                    }
+                }
+            }
+
+            return loop_state;
+        }
+
+        private static bool IsLoopExit(Operation op)
+        {
+            return op.OpType == "Exit" || op.OpType == "RefExit";
+        }
     }
 }
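`MaybeCreateControlFlowState` is the hook for while-loop gradients: every while loop leaves an `Exit` (or `RefExit`) op in the graph, so scanning `between_op_list` for those op types is a cheap loop detector. Until `ControlFlowState` is ported, the method returns null and callers such as `_PendingCount` propagate that null as "no while loops". A sketch of the contract a caller sees:

    var loop_state = control_flow_ops.MaybeCreateControlFlowState(
        between_op_list, between_ops, false);
    if (loop_state == null)
    {
        // Straight-line graph: the plain reverse walk in _GradientsHelper
        // is sufficient, no per-loop bookkeeping needed.
    }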
diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs
index a1c377dd..b4358e58 100644
--- a/src/TensorFlowNET.Core/Python.cs
+++ b/src/TensorFlowNET.Core/Python.cs
@@ -53,7 +53,7 @@ namespace Tensorflow
             }
         }
 
-        public static (T, T) zip<T>(T t1, T t2, int index = 0) where T : IList
+        public static (T, T) zip<T>(IList<T> t1, IList<T> t2, int index = 0)
         {
             return (t1[index], t2[index]);
         }
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 7a5b0d88..15ff64eb 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -2,6 +2,7 @@
 using NumSharp.Core;
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 
 namespace Tensorflow
@@ -15,5 +16,14 @@ namespace Tensorflow
         {
 
         }
+
+        /// <summary>
+        /// Returns True iff `self` is fully defined in every dimension.
+        /// </summary>
+        /// <returns></returns>
+        public bool is_fully_defined()
+        {
+            // A dimension of -1 is unknown, so a non-null Dimensions array
+            // alone is not enough; every entry must be known.
+            return Dimensions != null && Dimensions.All(x => x > -1);
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs
index 1e331c0b..e1f930f2 100644
--- a/src/TensorFlowNET.Core/Tensors/constant_op.cs
+++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs
@@ -21,6 +21,11 @@ namespace Tensorflow
         /// </summary>
         public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false)
         {
+            if (tf.context.executing_eagerly())
+            {
+
+            }
+
             Graph g = ops.get_default_graph();
             var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape);
             var tensor_value = new AttrValue
diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs
index a923240b..396fdd87 100644
--- a/src/TensorFlowNET.Core/Tensors/dtypes.cs
+++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs
@@ -77,5 +77,10 @@ namespace Tensorflow
                 (DataType)Enum.Parse(typeof(DataType), ((int)type - 100).ToString()) :
                 type;
         }
+
+        public static bool is_complex(this TF_DataType type)
+        {
+            return type == TF_DataType.TF_COMPLEX || type == TF_DataType.TF_COMPLEX64 || type == TF_DataType.TF_COMPLEX128;
+        }
     }
 }
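`is_complex` backs the guard in `_DefaultGradYs`: for real `ys` an all-ones default gradient is implied, but for complex `ys` a default of ones is not meaningful, so callers must pass `grad_ys` explicitly (mirroring Python TensorFlow's behavior). The extension method in use:

    var dt = TF_DataType.TF_COMPLEX64;
    if (dt.is_complex())
        throw new TypeAccessException(
            $"Gradients of complex tensors must set grad_ys (y.dtype = {dt})");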
diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs
index 77f2d3a9..244eda50 100644
--- a/src/TensorFlowNET.Core/Train/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Train/Optimizer.cs
@@ -68,7 +68,7 @@ namespace Tensorflow
             int? aggregation_method = null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             bool colocate_gradients_with_ops = false,
-            List<Tensor> grad_loss = null)
+            Tensor[] grad_loss = null)
         {
             int num_towers = 1;
             if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN)
@@ -85,9 +85,9 @@ namespace Tensorflow
             }
 
             var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
-            var var_refs = processors.Select(x => x.target()).ToList();
+            var var_refs = processors.Select(x => x.target()).ToArray();
 
-            gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss,
+            gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss,
                 gate_gradients: (gate_gradients == GateGradientType.GATE_OP),
                 aggregation_method: aggregation_method,
                 colocate_gradients_with_ops: colocate_gradients_with_ops);
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index dc4eaec5..a8708d30 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -278,5 +278,11 @@ namespace Tensorflow
         {
             return tf.Session();
         }
+
+        public static object get_gradient_function(Operation op)
+        {
+            if (op.inputs == null) return null;
+
+            // TODO: look up the gradient function registered for op.type;
+            // until then every op reports "no gradient".
+            return null;
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs
index 8edc70f5..3d6f1be1 100644
--- a/src/TensorFlowNET.Core/tf.cs
+++ b/src/TensorFlowNET.Core/tf.cs
@@ -14,7 +14,7 @@ namespace Tensorflow
         public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
         public static TF_DataType chars = TF_DataType.TF_STRING;
 
-        public static Context context;
+        public static Context context = new Context(new ContextOptions(), new Status());
 
         public static Graph g = new Graph();
 
diff --git a/test/TensorFlowNET.UnitTest/GradientTest.cs b/test/TensorFlowNET.UnitTest/GradientTest.cs
new file mode 100644
index 00000000..5fd1972d
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/GradientTest.cs
@@ -0,0 +1,20 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+
+namespace TensorFlowNET.UnitTest
+{
+    [TestClass]
+    public class GradientTest
+    {
+        [TestMethod]
+        public void Gradients()
+        {
+            var a = tf.constant(0.0);
+            var b = 2.0 * a;
+            var g = tf.gradients(a + b, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b });
+        }
+    }
+}
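The test exercises graph construction only; nothing is executed. The graph is the `stop_gradients` example from the Python `tf.gradients` documentation, where the result evaluates to `[1.0, 1.0]`: with `a` and `b` both stopped, each gradient is just the default all-ones `grad_ys`. Once session execution is wired up, the natural follow-up assertion would be something like (the `sess.run` call and its return shape are assumptions, not current API):

    // using (var sess = tf.Session())
    // {
    //     var r = sess.run(g);   // expected: [1.0, 1.0]
    //     Assert.AreEqual(1.0, r[0]);
    //     Assert.AreEqual(1.0, r[1]);
    // }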