From fed0550c12c56e50987ee544e0ee8ab0b15f6d2c Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Thu, 13 Jun 2019 21:48:40 -0500
Subject: [PATCH] fix rank_internal, add gradients_util

---
 TensorFlow.NET.sln                                  |   6 -
 src/TensorFlowNET.Core/APIs/tf.gradients.cs         |   4 +-
 .../Gradients/gradients_impl.py.cs                  | 485 +----------------
 .../Gradients/gradients_util.cs                     | 505 ++++++++++++++++++
 .../Operations/Operation.cs                         |   4 +-
 .../Operations/array_ops.py.cs                      |  22 +-
 src/TensorFlowNET.Core/Operations/math_ops.cs       |  14 -
 .../TensorFlowNET.Core.csproj                       |   7 +-
 .../TextProcess/CnnTextClassification.cs            |   2 -
 9 files changed, 534 insertions(+), 515 deletions(-)
 create mode 100644 src/TensorFlowNET.Core/Gradients/gradients_util.cs

diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index 63b6ef6d..51125309 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -17,8 +17,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.UnitTest", "test\Kera
 EndProject
 Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TensorFlowNET.Examples.FSharp", "test\TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{62BC3801-F0D3-44A9-A0AC-712F40C8F961}"
 EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{92762DCB-64C8-41B4-BEF7-780A969CE68F}"
-EndProject
 Global
 	GlobalSection(SolutionConfigurationPlatforms) = preSolution
 		Debug|Any CPU = Debug|Any CPU
@@ -53,10 +51,6 @@ Global
 		{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.Build.0 = Debug|Any CPU
 		{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.ActiveCfg = Release|Any CPU
 		{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.Build.0 = Release|Any CPU
-		{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
-		{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Debug|Any CPU.Build.0 = Debug|Any CPU
-		{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Release|Any CPU.ActiveCfg = Release|Any CPU
-		{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Release|Any CPU.Build.0 = Release|Any CPU
 	EndGlobalSection
 	GlobalSection(SolutionProperties) = preSolution
 		HideSolutionNode = FALSE

diff --git a/src/TensorFlowNET.Core/APIs/tf.gradients.cs b/src/TensorFlowNET.Core/APIs/tf.gradients.cs
index 77491d55..2e7c68c3 100644
--- a/src/TensorFlowNET.Core/APIs/tf.gradients.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.gradients.cs
@@ -15,7 +15,7 @@ namespace Tensorflow
             int? aggregation_method = null,
             Tensor[] stop_gradients = null)
         {
-            return gradients_impl._GradientsHelper(ys,
+            return gradients_util._GradientsHelper(ys,
                 xs,
                 grad_ys,
                 name,
@@ -33,7 +33,7 @@ namespace Tensorflow
             int? aggregation_method = null,
             Tensor[] stop_gradients = null)
        {
-            return gradients_impl._GradientsHelper(new Tensor[] { ys },
+            return gradients_util._GradientsHelper(new Tensor[] { ys },
                 xs,
                 grad_ys,
                 name,
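Both tf.gradients overloads above keep their public shape and only swap the helper they delegate to, so callers are unaffected. For orientation, a minimal usage sketch of that public API, not part of the patch; tf.placeholder, tf.multiply, FeedItem and the Session.run overload used here are assumptions about the graph-mode TF.NET API of this era:

    using System;
    using Tensorflow;
    using static Tensorflow.Python;

    class GradientsUsageSketch
    {
        static void Main()
        {
            // Build y = x * x in graph mode; dy/dx should be 2x.
            var x = tf.placeholder(tf.float32);
            var y = tf.multiply(x, x);

            // tf.gradients returns one gradient tensor per entry in xs.
            var grad = tf.gradients(y, new Tensor[] { x })[0];

            with(tf.Session(), sess =>
            {
                // Expected output: 6 (= 2 * 3).
                var result = sess.run(grad, new FeedItem(x, 3.0f));
                Console.WriteLine(result);
            });
        }
    }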
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
index 18151ac5..8ad4b44e 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
@@ -1,5 +1,4 @@
-using NumSharp;
-using System;
+using System;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
@@ -18,487 +17,7 @@ namespace Tensorflow
             bool gate_gradients = false,
             int? aggregation_method = null)
         {
-            return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients);
-        }
-
-        public static Tensor[] _GradientsHelper(Tensor[] ys,
-            Tensor[] xs,
-            Tensor[] grad_ys = null,
-            string name = "gradients",
-            bool colocate_gradients_with_ops = false,
-            bool gate_gradients = false,
-            int aggregation_method = 0,
-            Tensor[] stop_gradients = null,
-            Graph src_graph = null)
-        {
-            if (src_graph == null)
-                src_graph = ops.get_default_graph();
-
-            // If src_graph is a _FuncGraph (i.e. a function body), gather it and all
-            // ancestor graphs. This is necessary for correctly handling captured values.
-            var curr_graph = src_graph;
-
-            if (stop_gradients == null)
-                stop_gradients = new Tensor[0];
-            if (grad_ys == null)
-                grad_ys = new Tensor[ys.Length];
-
-            // Iterate over the collected ops.
-            /**
-             * grads: op => list of gradients received on each output endpoint of the
-             * op. The gradients for each endpoint are initially collected as a list.
-             * When it is time to call the op's gradient function, for each endpoint we
-             * aggregate the list of received gradients into a Add() Operation if there
-             * is more than one.
-             **/
-            var grads = new Dictionary<string, Tensor[][]>();
-
-            with(ops.name_scope(name, "gradients",
-                values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope =>
-            {
-                string grad_scope = scope;
-                // Get a uid for this call to gradients that can be used to help
-                // cluster ops for compilation.
-                var gradient_uid = ops.get_default_graph().unique_name("uid");
-                ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y");
-                xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true);
-                grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid);
-
-                /**
-                 * The approach we take here is as follows: Create a list of all ops in the
-                 * subgraph between the ys and xs. Visit these ops in reverse order of ids
-                 * to ensure that when we visit an op the gradients w.r.t its outputs have
-                 * been collected. Then aggregate these gradients if needed, call the op's
-                 * gradient function, and add the generated gradients to the gradients for
-                 * its input.
-                 **/
-
-                // Initialize the pending count for ops in the connected subgraph from ys
-                // to the xs.
-                var to_ops = ys.Select(x => x.op).ToList();
-                var from_ops = xs.Select(x => x.op).ToList();
-                var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
-                (var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
-
-                foreach (var (y, grad_y) in Python.zip(ys, grad_ys))
-                    _SetGrad(grads, y, grad_y);
-
-                // Initialize queue with to_ops.
-                var queue = new Queue<Operation>();
-                // Add the ops in 'to_ops' into the queue.
-                var to_ops_set = new List<Operation>();
-                foreach (var op in to_ops)
-                {
-                    // 'ready' handles the case where one output gradient relies on
-                    // another output's gradient.
-                    if (!pending_count.ContainsKey(op.name))
-                        pending_count[op.name] = 0;
-                    bool ready = pending_count[op.name] == 0;
-                    if (ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op))
-                    {
-                        to_ops_set.Add(op);
-                        queue.Enqueue(op);
-                    }
-                }
-
-                var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
-                while (queue.Count > 0)
-                {
-                    // generate gradient subgraph for op.
-                    var op = queue.Dequeue();
-                    _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
-                    //if (loop_state != null)
-                    //loop_state.EnterGradWhileContext(op, before: true);
-                    var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
-
-                    Tensor[] in_grads = null;
-                    var is_partitioned_call = _IsPartitionedCall(op);
-                    var is_func_call = false;
-                    var has_out_grads = true;
-                    if (has_out_grads && !stop_ops.Contains(op))
-                    {
-                        if (is_func_call)
-                        {
-
-                        }
-                        else
-                        {
-                            // A grad_fn must be defined, either as a function or as None
-                            // for ops that do not have gradients.
-                            var grad_fn = ops.get_gradient_function(op);
-
-                            foreach (var (i, out_grad) in enumerate(out_grads))
-                            {
-                                if (out_grad == null)
-                                {
-                                    if (loop_state != null)
-                                        ;
-                                    else
-                                        out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i);
-                                }
-                            }
-
-                            with(ops.name_scope(op.name + "_grad"), scope1 =>
-                            {
-                                string name1 = scope1;
-                                if (grad_fn != null)
-                                {
-                                    in_grads = _MaybeCompile(grad_scope, op, out_grads, null, grad_fn);
-                                    _VerifyGeneratedGradients(in_grads, op);
-                                }
-
-                                if (gate_gradients && in_grads.Count(x => x != null) > 1)
-                                {
-                                    ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
-                                    in_grads = control_flow_ops.tuple(in_grads);
-                                }
-                            });
-                        }
-                    }
-                    else
-                    {
-                        in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
-                    }
-
-                    var inputs = _NonEagerInputs(op, xs).ToList();
-                    foreach (var (t_in, in_grad) in zip(inputs, in_grads))
-                    {
-                        if (in_grad != null)
-                        {
-                            if (in_grad is Tensor && t_in.dtype != TF_DataType.TF_RESOURCE)
-                            {
-                                in_grad.shape = t_in.shape;
-                            }
-
-                            _SetGrad(grads, t_in, in_grad);
-                        }
-                    }
-
-                    // Update pending count for the inputs of op and enqueue ready ops.
-                    _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
-                }
-            });
-
-            return xs.Select(x => _GetGrad(grads, x)).ToArray();
-        }
-
-        /// <summary>
-        /// Update pending count for the inputs of op and enqueue ready ops.
-        /// </summary>
-        /// <param name="grads"></param>
-        /// <param name="op"></param>
-        /// <param name="queue"></param>
-        /// <param name="pending_count"></param>
-        /// <param name="loop_state"></param>
-        /// <param name="xs"></param>
-        private static void _UpdatePendingAndEnqueueReady(Dictionary<string, Tensor[][]> grads,
-            Operation op,
-            Queue<Operation> queue,
-            Dictionary<string, int> pending_count,
-            object loop_state,
-            Tensor[] xs)
-        {
-            foreach (var x in _NonEagerInputs(op, xs))
-            {
-                if (!pending_count.ContainsKey(x.op.name))
-                    pending_count[x.op.name] = 0;
-
-                pending_count[x.op.name] -= 1;
-
-                var ready = pending_count[x.op.name] == 0;
-
-                if (loop_state != null && !ready)
-                {
-
-                }
-
-                if (ready)
-                {
-                    if (control_flow_util.IsLoopExit(x.op))
-                    {
-
-                    }
-                    else
-                    {
-                        queue.Enqueue(x.op);
-                    }
-                }
-            }
-        }
-
-        private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
-        {
-            if (grads.Count() != op.inputs._inputs.Count())
-                throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
-                    $"inputs {op.inputs._inputs.Count()}");
-        }
-
-        private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
-        {
-            scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
-            return grad_fn(op, out_grads);
-        }
-
-        private static bool _IsPartitionedCall(Operation op)
-        {
-            return op.OpType == "PartitionedCall" || op.OpType == "StatefulPartitionedCall";
-        }
-
-        private static Tensor[] _AggregatedGrads(Dictionary<string, Tensor[][]> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
-        {
-            var out_grads = _GetGrads(grads, op);
-            var return_grads = new Tensor[out_grads.Length];
-
-            foreach (var (i, out_grad) in enumerate(out_grads))
-            {
-                if (loop_state != null)
-                {
-
-                }
-
-                // Aggregate multiple gradients, and convert [] to None.
-                if (out_grad != null)
-                {
-                    if (out_grad.Length < 2)
-                    {
-                        string used = "nop";
-                        return_grads[i] = out_grad[0];
-                    }
-                }
-            }
-
-            return return_grads;
-        }
-
-        /// <summary>
-        /// The set of ops that terminate the gradient computation.
-        /// </summary>
-        /// <param name="from_ops">list of Operations.</param>
-        /// <param name="stop_gradient_ops">list of Operations never to backprop through.</param>
-        /// <param name="pending_count">mapping from operation to number of backprop inputs.</param>
-        /// <param name="xs">list of Tensors.</param>
-        /// <returns>The set of operations.</returns>
-        private static Operation[] _StopOps(List<Operation> from_ops, List<Operation> stop_gradient_ops, Dictionary<string, int> pending_count, Tensor[] xs)
-        {
-            var stop_ops = new List<Operation>();
-
-            foreach (var op in from_ops)
-            {
-                bool is_stop_op = true;
-                foreach (var inp in _NonEagerInputs(op, xs))
-                {
-                    if (!pending_count.ContainsKey(inp.op.name))
-                        pending_count[inp.op.name] = 0;
-
-                    if (pending_count[inp.op.name] > 0)
-                    {
-                        is_stop_op = false;
-                        break;
-                    }
-                }
-                if (is_stop_op)
-                    stop_ops.Insert(0, op);
-            }
-            stop_ops.AddRange(stop_gradient_ops.Where(x => !stop_ops.Contains(x)));
-            return stop_ops.ToArray();
-        }
-
-        private static Tensor _GetGrad(Dictionary<string, Tensor[][]> grads, Tensor t)
-        {
-            var op = t.op;
-            if (!grads.ContainsKey(op.name))
-                return null;
-            Tensor[][] op_grads = grads[op.name];
-            var t_grad = op_grads[t.value_index];
-            return t_grad[0];
-        }
-
-        private static Tensor[][] _GetGrads(Dictionary<string, Tensor[][]> grads, Operation op)
-        {
-            if (grads.ContainsKey(op.name))
-                return grads[op.name];
-            else
-                return op.outputs.Select(x => new Tensor[0]).ToArray();
-        }
-
-        /// <summary>
-        /// Sets gradient "grad" in "grads" for tensor "t".
-        /// </summary>
-        /// <param name="grads"></param>
-        /// <param name="t"></param>
-        /// <param name="grad"></param>
-        private static void _SetGrad(Dictionary<string, Tensor[][]> grads, Tensor t, Tensor grad)
-        {
-            var op = t.op;
-            Tensor[][] op_grads = grads.ContainsKey(op.name) ? grads[op.name] : null;
-            if (op_grads == null)
-            {
-                op_grads = op.outputs.Select(x => new Tensor[1]).ToArray();
-                grads[op.name] = op_grads;
-            }
-            var t_grads = op_grads[t.value_index];
-            t_grads[0] = grad;
-        }
-
-        /// <summary>
-        /// Fill in default values for grad_ys.
-        /// </summary>
-        /// <param name="grad_ys">List of gradients, can contain None.</param>
-        /// <param name="ys">List of tensors.</param>
-        /// <param name="colocate_gradients_with_ops"></param>
-        /// <returns></returns>
-        private static Tensor[] _DefaultGradYs(Tensor[] grad_ys, Tensor[] ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__")
-        {
-            var new_grad_ys = new List<Tensor>();
-
-            for (int i = 0; i < grad_ys.Length; i++)
-            {
-                var grad_y = grad_ys[i];
-                var y = ys[i];
-
-                _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops);
-
-                if (grad_y == null)
-                {
-                    if (y.dtype.is_complex())
-                        throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})");
-                    var shape = array_ops.shape(y);
-                    var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}");
-                    var fill = gen_array_ops.fill(shape, constant);
-                    new_grad_ys.Add(fill);
-                }
-            }
-
-            return new_grad_ys.ToArray();
-        }
-
-        private static void _maybe_colocate_with(Operation op, string gradient_uid, bool colocate_gradients_with_ops)
-        {
-
-        }
-
-        /// <summary>
-        /// Initialize the pending count for ops between two lists of Operations.
-        /// 'pending_count[op]' indicates the number of backprop inputs
-        /// to this operation.
-        /// </summary>
-        /// <param name="to_ops"></param>
-        /// <param name="from_ops"></param>
-        /// <param name="colocate_gradients_with_ops"></param>
-        /// <param name="func_graphs"></param>
-        /// <param name="xs"></param>
-        private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
-        {
-            // Mark reachable ops from from_ops.
-            var reached_ops = new List<Operation>();
-            _MarkReachedOps(from_ops, reached_ops, func_graphs);
-            // X in reached_ops iff X is reachable from from_ops by a path of zero or more
-            // backpropagatable tensors.
-
-            var reachable_to_ops = to_ops.Where(x => reached_ops.Contains(x)).Select(x => x).ToArray();
-
-            var between_ops = new List<Operation>();
-            var between_op_list = new List<Operation>();
-
-            Queue<Operation> queue = new Queue<Operation>(to_ops);
-            while (queue.Count > 0)
-            {
-                var op = queue.Dequeue();
-                if (reached_ops.Contains(op))
-                {
-                    between_ops.Add(op);
-                    between_op_list.Insert(between_op_list.Count, op);
-                    // Clear the boolean so we won't add the inputs again.
-                    reached_ops.Remove(op);
-                    foreach (var inp in _NonEagerInputs(op, xs))
-                        queue.Enqueue(inp.op);
-                }
-            }
-            // X in between_ops iff X is on a path of zero or more backpropagatable tensors
-            // between from_ops and to_ops
-
-            // 'loop_state' is None if there are no while loops.
-            var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops);
-
-            var pending_count = new Dictionary<string, int>();
-            foreach (var op in between_op_list)
-            {
-                foreach (Tensor x in _NonEagerInputs(op, xs))
-                {
-                    if (between_ops.Contains(x.op))
-                    {
-                        if (!pending_count.ContainsKey(x.op.name))
-                            pending_count[x.op.name] = 0;
-
-                        pending_count[x.op.name] += 1;
-                    }
-                }
-            }
-
-            return (reachable_to_ops.ToArray(), pending_count, loop_state);
-        }
-
-        private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
-        {
-            for (int i = 0; i < op.inputs.Length; i++)
-                yield return op.inputs[i];
-        }
-
-        /// <summary>
-        /// Mark all ops reached from "from_ops"
-        /// </summary>
-        /// <param name="from_ops"></param>
-        /// <param name="reached_ops"></param>
-        /// <param name="func_graphs"></param>
-        private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<object> func_graphs)
-        {
-            Queue<Operation> queue = new Queue<Operation>(from_ops);
-            while (queue.Count > 0)
-            {
-                var op = queue.Dequeue();
-
-                if (!reached_ops.Contains(op))
-                {
-                    reached_ops.Add(op);
-                    foreach (var output in op.outputs)
-                    {
-                        if (_IsBackpropagatable(output))
-                        {
-                            var c = _Consumers(output, func_graphs).ToList();
-                            c.ForEach(x => queue.Enqueue(x));
-                        }
-                    }
-                }
-            }
-        }
-
-        private static bool _IsTrainable(Tensor tensor)
-        {
-            var dtype = tensor.dtype.as_base_dtype();
-            return new TF_DataType[] {TF_DataType.TF_HALF, TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE,
-                TF_DataType.TF_COMPLEX64, TF_DataType.TF_COMPLEX128, TF_DataType.TF_RESOURCE}.Contains(dtype);
-        }
-
-        private static bool _IsBackpropagatable(Tensor tensor)
-        {
-            if (_IsTrainable(tensor))
-            {
-                return true;
-            }
-            else
-            {
-                var dtype = tensor.dtype.as_base_dtype();
-                return new TF_DataType[] { TF_DataType.TF_BFLOAT16, TF_DataType.TF_VARIANT }.Contains(dtype);
-            }
-        }
-
-        /// <summary>
-        /// Returns the consumers of t, crossing closure boundaries where necessary.
- /// - /// - /// - private static Operation[] _Consumers(Tensor t, List func_graphs) - { - return t.consumers(); + return gradients_util._GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients); } private static List _AsList(object ys) diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs new file mode 100644 index 00000000..c68fdfa3 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -0,0 +1,505 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Python; + +namespace Tensorflow +{ + public class gradients_util + { + public static Tensor[] _GradientsHelper(Tensor[] ys, + Tensor[] xs, + Tensor[] grad_ys = null, + string name = "gradients", + bool colocate_gradients_with_ops = false, + bool gate_gradients = false, + int aggregation_method = 0, + Tensor[] stop_gradients = null, + Graph src_graph = null) + { + if (src_graph == null) + src_graph = ops.get_default_graph(); + + // If src_graph is a _FuncGraph (i.e. a function body), gather it and all + // ancestor graphs. This is necessary for correctly handling captured values. + var curr_graph = src_graph; + + if (stop_gradients == null) + stop_gradients = new Tensor[0]; + if (grad_ys == null) + grad_ys = new Tensor[ys.Length]; + + // Iterate over the collected ops. + /** + * grads: op => list of gradients received on each output endpoint of the + * op. The gradients for each endpoint are initially collected as a list. + * When it is time to call the op's gradient function, for each endpoint we + * aggregate the list of received gradients into a Add() Operation if there + * is more than one. + **/ + var grads = new Dictionary>>(); + + with(ops.name_scope(name, "gradients", + values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope => + { + string grad_scope = scope; + // Get a uid for this call to gradients that can be used to help + // cluster ops for compilation. + var gradient_uid = ops.get_default_graph().unique_name("uid"); + ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y"); + xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true); + grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid); + + /** + * The approach we take here is as follows: Create a list of all ops in the + * subgraph between the ys and xs. Visit these ops in reverse order of ids + * to ensure that when we visit an op the gradients w.r.t its outputs have + * been collected. Then aggregate these gradients if needed, call the op's + * gradient function, and add the generated gradients to the gradients for + * its input. + **/ + + // Initialize the pending count for ops in the connected subgraph from ys + // to the xs. + var to_ops = ys.Select(x => x.op).ToList(); + var from_ops = xs.Select(x => x.op).ToList(); + var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); + (var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs); + + foreach (var (y, grad_y) in zip(ys, grad_ys)) + _SetGrad(grads, y, grad_y); + + // Initialize queue with to_ops. + var queue = new Queue(); + // Add the ops in 'to_ops' into the queue. + var to_ops_set = new List(); + foreach (var op in to_ops) + { + // 'ready' handles the case where one output gradient relies on + // another output's gradient. 
+                    if (!pending_count.ContainsKey(op.name))
+                        pending_count[op.name] = 0;
+                    bool ready = pending_count[op.name] == 0;
+                    if (ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op))
+                    {
+                        to_ops_set.Add(op);
+                        queue.Enqueue(op);
+                    }
+                }
+
+                var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
+                while (queue.Count > 0)
+                {
+                    // generate gradient subgraph for op.
+                    var op = queue.Dequeue();
+                    _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
+                    //if (loop_state != null)
+                    //loop_state.EnterGradWhileContext(op, before: true);
+                    var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
+
+                    Tensor[] in_grads = null;
+                    var is_partitioned_call = _IsPartitionedCall(op);
+                    var is_func_call = false;
+                    var has_out_grads = true;
+                    if (has_out_grads && !stop_ops.Contains(op))
+                    {
+                        if (is_func_call)
+                        {
+
+                        }
+                        else
+                        {
+                            // A grad_fn must be defined, either as a function or as None
+                            // for ops that do not have gradients.
+                            var grad_fn = ops.get_gradient_function(op);
+
+                            foreach (var (i, out_grad) in enumerate(out_grads))
+                            {
+                                if (out_grad == null)
+                                {
+                                    if (loop_state != null)
+                                        ;
+                                    else
+                                        out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i);
+                                }
+                            }
+
+                            with(ops.name_scope(op.name + "_grad"), scope1 =>
+                            {
+                                string name1 = scope1;
+                                if (grad_fn != null)
+                                {
+                                    in_grads = _MaybeCompile(grad_scope, op, out_grads, null, grad_fn);
+                                    _VerifyGeneratedGradients(in_grads, op);
+                                }
+
+                                if (gate_gradients && in_grads.Count(x => x != null) > 1)
+                                {
+                                    ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
+                                    in_grads = control_flow_ops.tuple(in_grads);
+                                }
+                            });
+                        }
+                    }
+                    else
+                    {
+                        in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
+                    }
+
+                    var inputs = _NonEagerInputs(op, xs).ToList();
+                    foreach (var (t_in, in_grad) in zip(inputs, in_grads))
+                    {
+                        if (in_grad != null)
+                        {
+                            if (in_grad is Tensor && t_in.dtype != TF_DataType.TF_RESOURCE)
+                            {
+                                in_grad.shape = t_in.shape;
+                            }
+
+                            _SetGrad(grads, t_in, in_grad);
+                        }
+                    }
+
+                    // Update pending count for the inputs of op and enqueue ready ops.
+                    _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
+                }
+            });
+
+            return xs.Select(x => _GetGrad(grads, x)).ToArray();
+        }
+
+        /// <summary>
+        /// Fill in default values for grad_ys.
+        /// </summary>
+        /// <param name="grad_ys">List of gradients, can contain None.</param>
+        /// <param name="ys">List of tensors.</param>
+        /// <param name="colocate_gradients_with_ops"></param>
+        /// <returns></returns>
+        private static Tensor[] _DefaultGradYs(Tensor[] grad_ys, Tensor[] ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__")
+        {
+            var new_grad_ys = new List<Tensor>();
+
+            for (int i = 0; i < grad_ys.Length; i++)
+            {
+                var grad_y = grad_ys[i];
+                var y = ys[i];
+
+                _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops);
+
+                if (grad_y == null)
+                {
+                    if (y.dtype.is_complex())
+                        throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})");
+                    var shape = array_ops.shape(y);
+                    var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}");
+                    var fill = gen_array_ops.fill(shape, constant);
+                    new_grad_ys.Add(fill);
+                }
+            }
+
+            return new_grad_ys.ToArray();
+        }
+
+        private static void _maybe_colocate_with(Operation op, string gradient_uid, bool colocate_gradients_with_ops)
+        {
+
+        }
+
+        /// <summary>
+        /// Initialize the pending count for ops between two lists of Operations.
+        /// 'pending_count[op]' indicates the number of backprop inputs
+        /// to this operation.
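+        /// Also returns the ops in 'to_ops' that are reachable from 'from_ops', and a
+        /// loop state object when the graph contains while loops.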
+        /// </summary>
+        /// <param name="to_ops"></param>
+        /// <param name="from_ops"></param>
+        /// <param name="colocate_gradients_with_ops"></param>
+        /// <param name="func_graphs"></param>
+        /// <param name="xs"></param>
+        private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
+        {
+            // Mark reachable ops from from_ops.
+            var reached_ops = new List<Operation>();
+            _MarkReachedOps(from_ops, reached_ops, func_graphs);
+            // X in reached_ops iff X is reachable from from_ops by a path of zero or more
+            // backpropagatable tensors.
+
+            var reachable_to_ops = to_ops.Where(x => reached_ops.Contains(x)).Select(x => x).ToArray();
+
+            var between_ops = new List<Operation>();
+            var between_op_list = new List<Operation>();
+
+            Queue<Operation> queue = new Queue<Operation>(to_ops);
+            while (queue.Count > 0)
+            {
+                var op = queue.Dequeue();
+                if (reached_ops.Contains(op))
+                {
+                    between_ops.Add(op);
+                    between_op_list.Insert(between_op_list.Count, op);
+                    // Clear the boolean so we won't add the inputs again.
+                    reached_ops.Remove(op);
+                    foreach (var inp in _NonEagerInputs(op, xs))
+                        queue.Enqueue(inp.op);
+                }
+            }
+            // X in between_ops iff X is on a path of zero or more backpropagatable tensors
+            // between from_ops and to_ops
+
+            // 'loop_state' is None if there are no while loops.
+            var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops);
+
+            var pending_count = new Dictionary<string, int>();
+            foreach (var op in between_op_list)
+            {
+                foreach (Tensor x in _NonEagerInputs(op, xs))
+                {
+                    if (between_ops.Contains(x.op))
+                    {
+                        if (!pending_count.ContainsKey(x.op.name))
+                            pending_count[x.op.name] = 0;
+
+                        pending_count[x.op.name] += 1;
+                    }
+                }
+            }
+
+            return (reachable_to_ops.ToArray(), pending_count, loop_state);
+        }
+
+        /// <summary>
+        /// Sets gradient "grad" in "grads" for tensor "t".
+        /// </summary>
+        /// <param name="grads"></param>
+        /// <param name="t"></param>
+        /// <param name="grad"></param>
+        private static void _SetGrad(Dictionary<string, List<List<Tensor>>> grads, Tensor t, Tensor grad)
+        {
+            var op = t.op;
+            var op_grads = grads.ContainsKey(op.name) ? grads[op.name] : null;
+            if (op_grads == null)
+            {
+                op_grads = op.outputs.Select(x => new List<Tensor>()).ToList();
+                grads[op.name] = op_grads;
+            }
+            var t_grads = op_grads[t.value_index];
+            t_grads.Add(grad);
+        }
+
+        private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
+        {
+            for (int i = 0; i < op.inputs.Length; i++)
+                yield return op.inputs[i];
+        }
+
+        private static Tensor[] _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
+        {
+            var out_grads = _GetGrads(grads, op);
+            var return_grads = new Tensor[out_grads.Count];
+
+            foreach (var (i, out_grad) in enumerate(out_grads))
+            {
+                if (loop_state != null)
+                {
+
+                }
+
+                // Aggregate multiple gradients, and convert [] to None.
+                if (out_grad.Count > 0)
+                {
+                    if (out_grad.Count < 2)
+                    {
+                        return_grads[i] = out_grad[0];
+                    }
+                }
+                else
+                {
+                    return_grads[i] = null;
+                }
+            }
+
+            return return_grads;
+        }
+
+        /// <summary>
+        /// The set of ops that terminate the gradient computation.
+        /// </summary>
+        /// <param name="from_ops">list of Operations.</param>
+        /// <param name="stop_gradient_ops">list of Operations never to backprop through.</param>
+        /// <param name="pending_count">mapping from operation to number of backprop inputs.</param>
+        /// <param name="xs">list of Tensors.</param>
+        /// <returns>The set of operations.</returns>
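+        /// <remarks>
+        /// An op from 'from_ops' is a stop op when none of its non-eager inputs has a
+        /// positive pending count, i.e. no gradient still needs to flow through it.
+        /// </remarks>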
+        private static Operation[] _StopOps(List<Operation> from_ops, List<Operation> stop_gradient_ops, Dictionary<string, int> pending_count, Tensor[] xs)
+        {
+            var stop_ops = new List<Operation>();
+
+            foreach (var op in from_ops)
+            {
+                bool is_stop_op = true;
+                foreach (var inp in _NonEagerInputs(op, xs))
+                {
+                    if (!pending_count.ContainsKey(inp.op.name))
+                        pending_count[inp.op.name] = 0;
+
+                    if (pending_count[inp.op.name] > 0)
+                    {
+                        is_stop_op = false;
+                        break;
+                    }
+                }
+                if (is_stop_op)
+                    stop_ops.Insert(0, op);
+            }
+            stop_ops.AddRange(stop_gradient_ops.Where(x => !stop_ops.Contains(x)));
+            return stop_ops.ToArray();
+        }
+
+        private static Tensor _GetGrad(Dictionary<string, List<List<Tensor>>> grads, Tensor t)
+        {
+            var op = t.op;
+            if (!grads.ContainsKey(op.name))
+                return null;
+            var op_grads = grads[op.name];
+            var t_grad = op_grads[t.value_index];
+            return t_grad[0];
+        }
+
+        private static List<List<Tensor>> _GetGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op)
+        {
+            if (grads.ContainsKey(op.name))
+                return grads[op.name];
+            else
+                return op.outputs.Select(x => new List<Tensor>()).ToList();
+        }
+
+        /// <summary>
+        /// Mark all ops reached from "from_ops"
+        /// </summary>
+        /// <param name="from_ops"></param>
+        /// <param name="reached_ops"></param>
+        /// <param name="func_graphs"></param>
+        private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<object> func_graphs)
+        {
+            Queue<Operation> queue = new Queue<Operation>(from_ops);
+            while (queue.Count > 0)
+            {
+                var op = queue.Dequeue();
+
+                if (!reached_ops.Contains(op))
+                {
+                    reached_ops.Add(op);
+                    foreach (var output in op.outputs)
+                    {
+                        if (_IsBackpropagatable(output))
+                        {
+                            var c = _Consumers(output, func_graphs).ToList();
+                            c.ForEach(x => queue.Enqueue(x));
+                        }
+                    }
+                }
+            }
+        }
+
+        /// <summary>
+        /// Returns the consumers of t, crossing closure boundaries where necessary.
+        /// </summary>
+        /// <param name="t"></param>
+        /// <param name="func_graphs"></param>
+        private static Operation[] _Consumers(Tensor t, List<object> func_graphs)
+        {
+            return t.consumers();
+        }
+
+        private static bool _IsBackpropagatable(Tensor tensor)
+        {
+            if (_IsTrainable(tensor))
+            {
+                return true;
+            }
+            else
+            {
+                var dtype = tensor.dtype.as_base_dtype();
+                return new TF_DataType[] { TF_DataType.TF_BFLOAT16, TF_DataType.TF_VARIANT }.Contains(dtype);
+            }
+        }
+
+        private static bool _IsTrainable(Tensor tensor)
+        {
+            var dtype = tensor.dtype.as_base_dtype();
+            return new TF_DataType[] {TF_DataType.TF_HALF, TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE,
+                TF_DataType.TF_COMPLEX64, TF_DataType.TF_COMPLEX128, TF_DataType.TF_RESOURCE}.Contains(dtype);
+        }
+
+        private static bool _IsPartitionedCall(Operation op)
+        {
+            return op.OpType == "PartitionedCall" || op.OpType == "StatefulPartitionedCall";
+        }
+
+        /// <summary>
+        /// Update pending count for the inputs of op and enqueue ready ops.
+        /// </summary>
+        /// <param name="grads"></param>
+        /// <param name="op"></param>
+        /// <param name="queue"></param>
+        /// <param name="pending_count"></param>
+        /// <param name="loop_state"></param>
+        /// <param name="xs"></param>
+        private static void _UpdatePendingAndEnqueueReady(Dictionary<string, List<List<Tensor>>> grads,
+            Operation op,
+            Queue<Operation> queue,
+            Dictionary<string, int> pending_count,
+            object loop_state,
+            Tensor[] xs)
+        {
+            foreach (var x in _NonEagerInputs(op, xs))
+            {
+                if (!pending_count.ContainsKey(x.op.name))
+                    pending_count[x.op.name] = 0;
+
+                pending_count[x.op.name] -= 1;
+
+                var ready = pending_count[x.op.name] == 0;
+
+                if (loop_state != null && !ready)
+                {
+
+                }
+
+                if (ready)
+                {
+                    if (control_flow_util.IsLoopExit(x.op))
+                    {
+
+                    }
+                    else
+                    {
+                        queue.Enqueue(x.op);
+                    }
+                }
+            }
+        }
+
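+        // NOTE: unlike the Python original, this port does not attempt XLA
+        // compilation in _MaybeCompile: 'func' is ignored and the gradient
+        // function is invoked directly, after a trailing "/" is trimmed off
+        // the scope name.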
+        private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
+        {
+            scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
+            return grad_fn(op, out_grads);
+        }
+
+        private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
+        {
+            if (grads.Count() != op.inputs._inputs.Count())
+                throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
+                    $"inputs {op.inputs._inputs.Count()}");
+        }
+    }
+}

diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index 4b032dea..c5bd77f6 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -55,12 +55,14 @@ namespace Tensorflow
         public TF_DataType dtype => TF_DataType.DtInvalid;
         private Status status = new Status();
 
-        public string name => c_api.StringPiece(c_api.TF_OperationName(_handle));
+        public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle));
         public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));
         public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
 
         private NodeDef _node_def;
+#if GRAPH_SERIALIZE
         [JsonIgnore]
+#endif
         public NodeDef node_def
         {
             get

diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 18f158fe..14f79bd7 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -127,8 +127,28 @@ namespace Tensorflow
         private static Tensor expand_dims_v2(Tensor input, int axis, string name = null)
             => gen_array_ops.expand_dims(input, axis, name);
 
+        /// <summary>
+        /// Returns the rank of a tensor.
+        /// </summary>
+        /// <param name="input"></param>
+        /// <param name="name"></param>
+        /// <returns></returns>
         public static Tensor rank(Tensor input, string name = null)
-            => math_ops.rank_internal(input, name, optimize: true);
+            => rank_internal(input, name, optimize: true);
+
+        public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true)
+        {
+            return with(ops.name_scope(name, "Rank", new List<Tensor> { input }), scope =>
+            {
+                name = scope;
+                var input_tensor = ops.convert_to_tensor(input);
+                var input_shape = tensor_util.to_shape(input_tensor.shape);
+                if (optimize && input_shape.NDim > -1)
+                    return constant_op.constant(input_shape.NDim, dtype: tf.int32, name: name);
+                else
+                    return gen_array_ops.rank(input, name);
+            });
+        }
 
         /// <summary>
         /// Creates a tensor with all elements set to 1.
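The behavioral point of the rank_internal fix above: the old math_ops version (removed below) gated its optimize branch on input_shape.NDim == null, the inverse of the useful condition, so a known static rank was never folded into a constant; the new array_ops version folds the rank to an int32 constant exactly when the static rank is known (NDim > -1) and otherwise emits a runtime Rank op. A small illustration, not from the patch; the TensorShape type and the shape-taking placeholder overload are assumptions:

    // Static shape known at graph-construction time:
    // rank folds to the graph constant 2, no Rank op is emitted.
    var a = tf.placeholder(tf.float32, shape: new TensorShape(3, 4));
    var ra = array_ops.rank(a);

    // Rank unknown at graph-construction time (NDim == -1):
    // falls back to the runtime op gen_array_ops.rank(b).
    var b = tf.placeholder(tf.float32);
    var rb = array_ops.rank(b);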
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index f37bd0dd..051feecf 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -429,20 +429,6 @@ namespace Tensorflow
             });
         }
 
-        public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true)
-        {
-            return with(ops.name_scope(name, "Rank", new List<Tensor> { input }), scope =>
-            {
-                name = scope;
-                var input_tensor = ops.convert_to_tensor(input);
-                var input_shape = tensor_util.to_shape(input_tensor.shape);
-                if (optimize && input_shape.NDim == null)
-                    return constant_op.constant(input_shape.NDim);
-                else
-                    return gen_array_ops.rank(input, name);
-            });
-        }
-
         public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null)
             => gen_math_ops.maximum(x, y, name: name);

diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index af3b7843..1e532dc8 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -29,7 +29,7 @@ Docs: https://tensorflownet.readthedocs.io
     true
-    <DefineConstants>TRACE;DEBUG;GRAPH_SERIALIZE</DefineConstants>
+    <DefineConstants>TRACE;DEBUG</DefineConstants>
 
 
@@ -48,7 +48,6 @@ Docs: https://tensorflownet.readthedocs.io
 
 
-
 
 
@@ -63,8 +62,4 @@ Docs: https://tensorflownet.readthedocs.io
 
 
-
-
-
-

diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
index e53a6f3b..60b6d050 100644
--- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
@@ -308,8 +308,6 @@ namespace TensorFlowNET.Examples
         {
             var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
 
-            var imported_graph = JsonConvert.SerializeObject(graph, new JsonSerializerSettings { Formatting = Formatting.Indented });
-
             return with(tf.Session(graph), sess => Train(sess, graph));
         }