@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Operations.ControlFlows;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -82,6 +83,7 @@ namespace Tensorflow
var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
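// 'loop_state' tracks the while-loop state between 'from_ops' and 'to_ops';
// it is null when the graph contains no while loops (see _PendingCount below).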

// Add the initial gradients for the ys.
foreach (var (y, grad_y) in zip(ys, grad_ys))
    _SetGrad(grads, y, grad_y);
@@ -103,12 +105,25 @@ namespace Tensorflow
    }
}

if (loop_state != null)
{
    var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set);
    foreach (var y in loop_exits)
    {
        // TODO (per upstream TF): for trainable exits, set a zero gradient
        // via loop_state.ZerosLikeForExit(y) and enqueue y.op; skip the rest.
        //if (IsTrainable(y))
        throw new NotImplementedException("ProcessUnusedLoopExits");
    }
}

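// 'stop_ops' marks the frontier where backpropagation stops: roughly, ops
// none of whose inputs take part in the gradient computation, plus the
// explicit 'stop_gradients' ops.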
var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
while (queue.Count > 0)
{
    // generate gradient subgraph for op.
    var op = queue.Dequeue();
    _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
    //if (loop_state != null)
    //    loop_state.EnterGradWhileContext(op, before: true);
@@ -147,8 +162,8 @@ namespace Tensorflow
    }
}
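
// Enter the gradient while-loop context (if any) so the gradient ops for
// 'op' are created inside the matching while-loop frame.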
if (loop_state != null)
    loop_state.EnterGradWhileContext(op, before: false);

if ((is_func_call || grad_fn != null) && has_out_grads)
{
@@ -164,7 +179,7 @@ namespace Tensorflow
// will use SymbolicGradient get a zero gradient. Gradient
// functions should ignore the gradient for other outputs.
if (loop_state != null)
    out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
else
    out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
}
@@ -275,7 +290,7 @@ namespace Tensorflow
/// <param name="colocate_gradients_with_ops"></param>
/// <param name="func_graphs"></param>
/// <param name="xs"></param>
private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
{
    // Mark reachable ops from from_ops.
    var reached_ops = new List<Operation>();
@@ -308,6 +323,7 @@ namespace Tensorflow
// 'loop_state' is null if there are no while loops.
var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops);

// Initialize pending count for between ops.
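// pending_count[op.name] counts how many "between" ops consume op's
// outputs; op is enqueued for backprop once this count drains to zero.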
var pending_count = new Dictionary<string, int>();
foreach (var op in between_op_list)
{
@@ -550,7 +566,7 @@ namespace Tensorflow
    Operation op,
    Queue<Operation> queue,
    Dictionary<string, int> pending_count,
    ControlFlowState loop_state,
    Tensor[] xs)
{
    foreach (var x in _NonEagerInputs(op, xs))
@@ -564,14 +580,49 @@ namespace Tensorflow
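// A Switch op inside a while loop counts as ready as soon as any of its
// output gradients is available, even while other consumers are pending.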
if (loop_state != null && !ready)
{
    ready = pending_count[x.op.name] > 0 && control_flow_util.IsLoopSwitch(x.op);
}

if (ready)
{
    // If x is an exit without a real gradient, defer processing it.
    if (control_flow_util.IsLoopExit(x.op))
    {
        var grad_state = loop_state.GetGradState(x.op, before: false);
        grad_state.deferred_exits.append(x);
        grad_state.pending_exits_count -= 1;

        // We now have all the exits so process them.
        if (grad_state.pending_exits_count == 0)
        {
            var has_not_none_grad = false;
            foreach (var y in grad_state.deferred_exits)
            {
                if (_HasAnyNotNoneGrads(grads, y.op))
                {
                    has_not_none_grad = true;
                    queue.Enqueue(y.op);
                }
                else
                    grad_state.unused_exits.append(y);
            }

            if (has_not_none_grad)
            {
                // For an unused exit, if it has trainable outputs, backprop
                // a zero gradient. Otherwise, just ignore it.
                foreach (var y in grad_state.unused_exits)
                {
                    if (IsTrainable(y))
                        _SetGrad(grads, y, loop_state.ZerosLikeForExit(y));
                    queue.Enqueue(y.op);
                }
            }
            else
            {
                // All exits are "unused", so leave their gradient as null.
                foreach (var y in grad_state.unused_exits)
                    queue.Enqueue(y.op);
            }
        }
    }
    else
    {
@@ -581,6 +632,32 @@ namespace Tensorflow
    }
}
private static bool IsTrainable(Tensor tensor)
{
    var dtype = tensor.dtype.as_base_dtype();
    return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64,
                               dtypes.complex64, dtypes.complex128,
                               dtypes.resource, dtypes.variant }.Contains(dtype);
}
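
// Sketch (hypothetical calls, not part of this diff): only the dtypes listed
// above receive gradients, so e.g. an int32 tensor is not trainable:
//   IsTrainable(tf.constant(1));     // false (int32)
//   IsTrainable(tf.constant(1.0f));  // true  (float32)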
/// <summary>
/// Return true if any gradient recorded for 'op' is non-null.
/// </summary>
/// <param name="grads"></param>
/// <param name="op"></param>
/// <returns></returns>
private static bool _HasAnyNotNoneGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op)
{
    var out_grads = _GetGrads(grads, op);
    foreach (var out_grad in out_grads)
    {
        if (out_grad.Exists(g => g != null))
            return true;
    }
    return false;
}
private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
{
    scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
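    // (In upstream TensorFlow, this trimmed scope is used to look up XLA
    // compile attributes before optionally compiling grad_fn.)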
@@ -589,6 +666,9 @@ namespace Tensorflow
private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
{
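    // While-loop ops are special-cased: the number of generated gradients
    // need not match the op's input count, so skip the check for them.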
    if (op.type == "While" || op.type == "StatelessWhile")
        return;

    if (grads.Length != op.inputs._inputs.Count())
        throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
            $"inputs {op.inputs._inputs.Count()}");