| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Operations.ControlFlows; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -82,6 +83,7 @@ namespace Tensorflow | |||||
| var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); | var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); | ||||
| var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs); | var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs); | ||||
| // Add the initial gradients for the ys. | |||||
| foreach (var (y, grad_y) in zip(ys, grad_ys)) | foreach (var (y, grad_y) in zip(ys, grad_ys)) | ||||
| _SetGrad(grads, y, grad_y); | _SetGrad(grads, y, grad_y); | ||||
| @@ -103,12 +105,25 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| if(loop_state != null) | |||||
| { | |||||
| var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set); | |||||
| foreach(var y in loop_exits) | |||||
| { | |||||
| //if(IsTrainable(y)) | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs); | var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs); | ||||
| while (queue.Count > 0) | while (queue.Count > 0) | ||||
| { | { | ||||
| // generate gradient subgraph for op. | // generate gradient subgraph for op. | ||||
| var op = queue.Dequeue(); | var op = queue.Dequeue(); | ||||
| if(op.name == "rnn/while/basic_rnn_cell/Tanh") | |||||
| { | |||||
| } | |||||
| _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); | _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); | ||||
| //if (loop_state != null) | //if (loop_state != null) | ||||
| //loop_state.EnterGradWhileContext(op, before: true); | //loop_state.EnterGradWhileContext(op, before: true); | ||||
| @@ -147,8 +162,8 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| // if (loop_state) | |||||
| //loop_state.EnterGradWhileContext(op, before: false); | |||||
| if (loop_state != null) | |||||
| loop_state.EnterGradWhileContext(op, before: false); | |||||
| if ((is_func_call || grad_fn != null) && has_out_grads) | if ((is_func_call || grad_fn != null) && has_out_grads) | ||||
| { | { | ||||
| @@ -164,7 +179,7 @@ namespace Tensorflow | |||||
| // will use SymbolicGradient get a zero gradient. Gradient | // will use SymbolicGradient get a zero gradient. Gradient | ||||
| // functions should ignore the gradient for other outputs. | // functions should ignore the gradient for other outputs. | ||||
| if (loop_state != null) | if (loop_state != null) | ||||
| ; | |||||
| out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) }; | |||||
| else | else | ||||
| out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) }; | out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) }; | ||||
| } | } | ||||
| @@ -275,7 +290,7 @@ namespace Tensorflow | |||||
| /// <param name="colocate_gradients_with_ops"></param> | /// <param name="colocate_gradients_with_ops"></param> | ||||
| /// <param name="func_graphs"></param> | /// <param name="func_graphs"></param> | ||||
| /// <param name="xs"></param> | /// <param name="xs"></param> | ||||
| private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs) | |||||
| private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs) | |||||
| { | { | ||||
| // Mark reachable ops from from_ops. | // Mark reachable ops from from_ops. | ||||
| var reached_ops = new List<Operation>(); | var reached_ops = new List<Operation>(); | ||||
| @@ -308,6 +323,7 @@ namespace Tensorflow | |||||
| // 'loop_state' is None if there are no while loops. | // 'loop_state' is None if there are no while loops. | ||||
| var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops); | var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops); | ||||
| // Initialize pending count for between ops. | |||||
| var pending_count = new Dictionary<string, int>(); | var pending_count = new Dictionary<string, int>(); | ||||
| foreach (var op in between_op_list) | foreach (var op in between_op_list) | ||||
| { | { | ||||
| @@ -550,7 +566,7 @@ namespace Tensorflow | |||||
| Operation op, | Operation op, | ||||
| Queue<Operation> queue, | Queue<Operation> queue, | ||||
| Dictionary<string, int> pending_count, | Dictionary<string, int> pending_count, | ||||
| object loop_state, | |||||
| ControlFlowState loop_state, | |||||
| Tensor[] xs) | Tensor[] xs) | ||||
| { | { | ||||
| foreach (var x in _NonEagerInputs(op, xs)) | foreach (var x in _NonEagerInputs(op, xs)) | ||||
| @@ -564,14 +580,49 @@ namespace Tensorflow | |||||
| if (loop_state != null && !ready) | if (loop_state != null && !ready) | ||||
| { | { | ||||
| ready = pending_count[x.op.name] > 0 && control_flow_util.IsLoopSwitch(x.op); | |||||
| } | } | ||||
| if (ready) | if (ready) | ||||
| { | { | ||||
| // if x is an exit without real gradient, defer processing them. | |||||
| if (control_flow_util.IsLoopExit(x.op)) | if (control_flow_util.IsLoopExit(x.op)) | ||||
| { | { | ||||
| var grad_state = loop_state.GetGradState(x.op, before: false); | |||||
| grad_state.deferred_exits.append(x); | |||||
| grad_state.pending_exits_count -= 1; | |||||
| // We now have all the exits so process them. | |||||
| if (grad_state.pending_exits_count == 0) | |||||
| { | |||||
| var has_not_none_grad = false; | |||||
| foreach(var y in grad_state.deferred_exits) | |||||
| { | |||||
| if (_HasAnyNotNoneGrads(grads, y.op)) | |||||
| { | |||||
| has_not_none_grad = true; | |||||
| queue.Enqueue(y.op); | |||||
| } | |||||
| else | |||||
| grad_state.unused_exits.append(y); | |||||
| } | |||||
| if (has_not_none_grad) | |||||
| { | |||||
| // For an unused exit, if it has trainable outputs, backprop | |||||
| // a zero gradient. Otherwise, just ignore it. | |||||
| foreach (var y in grad_state.unused_exits) | |||||
| { | |||||
| if (IsTrainable(y)) | |||||
| _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)); | |||||
| queue.Enqueue(y.op); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| // All exits are "unused" so use None as gradient. | |||||
| foreach (var y in grad_state.unused_exits) | |||||
| queue.Enqueue(y.op); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -581,6 +632,32 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Reports whether the base dtype of <paramref name="tensor"/> is one of
/// float16, float32, float64, complex64, complex128, resource or variant.
/// </summary>
/// <param name="tensor">Tensor whose dtype is examined.</param>
/// <returns>True when the base dtype appears in the list above; otherwise false.</returns>
private static bool IsTrainable(Tensor tensor)
{
    var trainable_types = new TF_DataType[]
    {
        dtypes.float16,
        dtypes.float32,
        dtypes.float64,
        dtypes.complex64,
        dtypes.complex128,
        dtypes.resource,
        dtypes.variant
    };

    // Compare on the base dtype, as returned by as_base_dtype().
    var base_type = tensor.dtype.as_base_dtype();
    return Array.IndexOf(trainable_types, base_type) >= 0;
}
/// <summary>
/// Return true if op has real gradient, i.e. at least one of the gradient
/// lists returned by _GetGrads for <paramref name="op"/> holds a non-null entry.
/// </summary>
/// <param name="grads">Gradient table, forwarded unchanged to _GetGrads.</param>
/// <param name="op">Operation whose gradient lists are inspected.</param>
/// <returns>True when any gradient entry is non-null; false when all are null or the lists are empty.</returns>
private static bool _HasAnyNotNoneGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op)
{
    // Any() stops at the first list that contains a non-null gradient.
    return _GetGrads(grads, op).Any(per_output => per_output.Exists(g => g != null));
}
| private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn) | private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn) | ||||
| { | { | ||||
| scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; | scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; | ||||
| @@ -589,6 +666,9 @@ namespace Tensorflow | |||||
| private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op) | private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op) | ||||
| { | { | ||||
| if (op.type == "While" || op.type == "StatelessWhile") | |||||
| return; | |||||
| if (grads.Count() != op.inputs._inputs.Count()) | if (grads.Count() != op.inputs._inputs.Count()) | ||||
| throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + | throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + | ||||
| $"inputs {op.inputs._inputs.Count()}"); | $"inputs {op.inputs._inputs.Count()}"); | ||||