diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
index 15ad511b..163192ee 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
@@ -17,6 +17,7 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using Tensorflow.Operations.ControlFlows;
 using static Tensorflow.Binding;
 
 namespace Tensorflow
@@ -82,6 +83,7 @@ namespace Tensorflow
             var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
             var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
 
+            // Add the initial gradients for the ys.
             foreach (var (y, grad_y) in zip(ys, grad_ys))
                 _SetGrad(grads, y, grad_y);
 
@@ -103,12 +105,25 @@ namespace Tensorflow
                 }
             }
 
+            if (loop_state != null)
+            {
+                var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set);
+                foreach (var y in loop_exits)
+                {
+                    //if (IsTrainable(y))
+                    throw new NotImplementedException("");
+                }
+            }
+
             var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
 
             while (queue.Count > 0)
             {
                 // generate gradient subgraph for op.
                 var op = queue.Dequeue();
+                if (op.name == "rnn/while/basic_rnn_cell/Tanh")
+                { // empty block: debugging breakpoint anchor for the RNN while-loop Tanh op
+                }
                 _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
                 //if (loop_state != null)
                 //loop_state.EnterGradWhileContext(op, before: true);
@@ -147,8 +162,8 @@ namespace Tensorflow
                     }
                 }
 
-                // if (loop_state)
-                //loop_state.EnterGradWhileContext(op, before: false);
+                if (loop_state != null)
+                    loop_state.EnterGradWhileContext(op, before: false);
 
                 if ((is_func_call || grad_fn != null) && has_out_grads)
                 {
@@ -164,7 +179,7 @@ namespace Tensorflow
                             // will use SymbolicGradient get a zero gradient. Gradient
                             // functions should ignore the gradient for other outputs.
                             if (loop_state != null)
-                                ;
+                                out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
                             else
                                 out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
                         }
@@ -275,7 +290,7 @@ namespace Tensorflow
         /// <param name="func_graphs"></param>
         /// <param name="xs"></param>
         /// <returns></returns>
-        private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
+        private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
         {
             // Mark reachable ops from from_ops.
             var reached_ops = new List<Operation>();
@@ -308,6 +323,7 @@ namespace Tensorflow
             // 'loop_state' is None if there are no while loops.
             var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops);
 
+            // Initialize pending count for between ops.
             var pending_count = new Dictionary<string, int>();
             foreach (var op in between_op_list)
             {
@@ -550,7 +566,7 @@ namespace Tensorflow
             Operation op,
             Queue<Operation> queue,
             Dictionary<string, int> pending_count,
-            object loop_state,
+            ControlFlowState loop_state,
             Tensor[] xs)
         {
             foreach (var x in _NonEagerInputs(op, xs))
@@ -564,14 +580,49 @@ namespace Tensorflow
 
                 if (loop_state != null && !ready)
                 {
-
+                    ready = pending_count[x.op.name] > 0 && control_flow_util.IsLoopSwitch(x.op);
                 }
 
                 if (ready)
                 {
+                    // If x is an exit without a real gradient, defer processing it.
                     if (control_flow_util.IsLoopExit(x.op))
                     {
-
+                        var grad_state = loop_state.GetGradState(x.op, before: false);
+                        grad_state.deferred_exits.append(x);
+                        grad_state.pending_exits_count -= 1;
+                        // We now have all the exits so process them.
+                        if (grad_state.pending_exits_count == 0)
+                        {
+                            var has_not_none_grad = false;
+                            foreach (var y in grad_state.deferred_exits)
+                            {
+                                if (_HasAnyNotNoneGrads(grads, y.op))
+                                {
+                                    has_not_none_grad = true;
+                                    queue.Enqueue(y.op);
+                                }
+                                else
+                                    grad_state.unused_exits.append(y);
+                            }
+                            if (has_not_none_grad)
+                            {
+                                // For an unused exit, if it has trainable outputs, backprop
+                                // a zero gradient. Otherwise, just ignore it.
+                                foreach (var y in grad_state.unused_exits)
+                                {
+                                    if (IsTrainable(y))
+                                        _SetGrad(grads, y, loop_state.ZerosLikeForExit(y));
+                                    queue.Enqueue(y.op);
+                                }
+                            }
+                            else
+                            {
+                                // All exits are "unused" so use None as gradient.
+                                foreach (var y in grad_state.unused_exits)
+                                    queue.Enqueue(y.op);
+                            }
+                        }
                     }
                     else
                     {
@@ -581,6 +632,32 @@ namespace Tensorflow
                 }
             }
         }
+        private static bool IsTrainable(Tensor tensor)
+        {
+            var dtype = tensor.dtype.as_base_dtype();
+            return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64,
+                dtypes.complex64, dtypes.complex128,
+                dtypes.resource, dtypes.variant }.Contains(dtype);
+        }
+
+        /// <summary>
+        /// Return true if op has real gradient.
+        /// </summary>
+        /// <param name="grads"></param>
+        /// <param name="op"></param>
+        /// <returns></returns>
+        private static bool _HasAnyNotNoneGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op)
+        {
+            var out_grads = _GetGrads(grads, op);
+            foreach (var out_grad in out_grads)
+            {
+                if (out_grad.Exists(g => g != null))
+                    return true;
+            }
+            return false;
+        }
+
+
         private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
         {
             scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
@@ -589,6 +666,9 @@ namespace Tensorflow
 
         private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
         {
+            if (op.type == "While" || op.type == "StatelessWhile")
+                return; // while-loop gradient functions may return a different number of gradients
+
             if (grads.Count() != op.inputs._inputs.Count())
                 throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
                     $"inputs {op.inputs._inputs.Count()}");
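
Note on the scheduling above: _PendingCount plus the ready queue in _GradientsHelper implement a reverse topological walk over the ops between xs and ys, driven by per-op pending counts. The following stand-alone sketch distills just that scheduling under simplified assumptions; Node, BackpropOrder, and Traverse are hypothetical names for illustration, not TF.NET API.

    using System.Collections.Generic;
    using System.Linq;

    class Node
    {
        public string Name;
        public List<Node> Inputs = new List<Node>(); // ops this node consumes
    }

    static class BackpropOrder
    {
        // Yields nodes in an order where every consumer's gradient is
        // available before a node's own gradient is computed.
        public static IEnumerable<Node> Traverse(IList<Node> outputs)
        {
            // pending[n] = number of consumers of n not yet processed.
            var pending = new Dictionary<Node, int>();
            var reached = new HashSet<Node>();
            var stack = new Stack<Node>(outputs);
            while (stack.Count > 0)
            {
                var n = stack.Pop();
                if (!reached.Add(n)) continue;
                foreach (var x in n.Inputs)
                {
                    pending[x] = pending.TryGetValue(x, out var c) ? c + 1 : 1;
                    stack.Push(x);
                }
            }

            // Seed with output ops that no other output consumes, mirroring
            // the initial queue built before the while loop in the diff.
            var queue = new Queue<Node>(outputs.Where(o => !pending.ContainsKey(o) || pending[o] == 0));
            while (queue.Count > 0)
            {
                var op = queue.Dequeue();
                yield return op; // "generate gradient subgraph for op"
                foreach (var x in op.Inputs)
                {
                    pending[x] -= 1;
                    if (pending[x] == 0)
                        queue.Enqueue(x); // all consumer gradients available
                }
            }
        }
    }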
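The new Exit handling in _UpdatePendingAndEnqueueReady defers each loop Exit until pending_exits_count reaches zero, then enqueues exits with real gradients, backprops zeros into trainable-but-unused exits (ZerosLikeForExit), and lets all exits pass None through when no exit has a gradient. A minimal stand-alone distillation of that decision follows; Exit, GradLoopState, and ExitProcessing are hypothetical stand-ins, not the library's types.

    using System.Collections.Generic;

    class Exit
    {
        public string Name;
        public bool HasGrad;     // does any consumer supply a not-None gradient?
        public bool IsTrainable; // float/complex/resource/variant dtype, per IsTrainable
    }

    class GradLoopState
    {
        public int PendingExitsCount;
        public List<Exit> DeferredExits = new List<Exit>();
        public List<Exit> UnusedExits = new List<Exit>();
    }

    static class ExitProcessing
    {
        // Called once per Exit op; yields the exits ready to enqueue
        // once every exit of the loop has been seen.
        public static IEnumerable<Exit> Defer(GradLoopState state, Exit x)
        {
            state.DeferredExits.Add(x);
            state.PendingExitsCount -= 1;
            if (state.PendingExitsCount != 0)
                yield break; // still waiting for sibling exits

            var hasRealGrad = false;
            foreach (var y in state.DeferredExits)
            {
                if (y.HasGrad) { hasRealGrad = true; yield return y; }
                else state.UnusedExits.Add(y);
            }
            foreach (var y in state.UnusedExits)
            {
                // With at least one real gradient, trainable unused exits get
                // zeros (ZerosLikeForExit in the diff); otherwise every exit
                // is enqueued with None as its gradient.
                if (hasRealGrad && y.IsTrainable)
                    y.HasGrad = true; // stands in for _SetGrad(grads, y, zeros)
                yield return y;
            }
        }
    }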