From ad250d0c796b12e496ee3a18998d46da48284547 Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Sun, 3 Nov 2019 10:42:17 -0600
Subject: [PATCH] implement _SwitchGrad when merge_grad is not null.

---
 .../Gradients/control_flow_grad.cs |  19 ++-
 .../Gradients/gradients_util.cs    | 111 +++++++++---------
 .../Operations/Operation.cs        |   5 +-
 3 files changed, 67 insertions(+), 68 deletions(-)

diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
index acaa6de3..3ae890fb 100644
--- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
@@ -48,7 +48,12 @@ namespace Tensorflow.Gradients
                    {
                        var merge_grad = grad_ctxt.grad_state.switch_map.get(op);
                        if (merge_grad != null)
-                            throw new NotImplementedException("_SwitchGrad merge_grad != null");
+                        {
+                            if (grads[1] != null)
+                                control_flow_ops._AddNextAndBackEdge(merge_grad, grads[1],
+                                    enforce_shape_invariant: false);
+                            return new Tensor[] { null, null };
+                        }
                        else if (grads[0] != null)
                        {
                            merge_grad = merge(new[] { grads[0], grads[0] }, name: "b_switch")[0];
@@ -233,17 +238,9 @@ namespace Tensorflow.Gradients
                return grads;
            if (op.get_attr("is_constant"))
            {
-                throw new NotImplementedException("_EnterGrad is_constant");
-                // Add a gradient accumulator for each loop invariant.
-                // if isinstance(grad, ops.Tensor) :
-                //     result = grad_ctxt.AddBackpropAccumulator(op, grad)
-                // elif isinstance(grad, ops.IndexedSlices) :
-                //     result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
-                // else:
-                //     # TODO(yuanbyu, lukasr): Add support for SparseTensor.
-                //     raise TypeError("Type %s not supported" % type(grad))
+                // Add a gradient accumulator for each loop invariant.
+                result = grad_ctxt.AddBackpropAccumulator(op, grad);
            }
-
            else
            {
                result = control_flow_ops.exit(grad);
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
index 8170ea6f..55b771ed 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
@@ -123,10 +123,7 @@ namespace Tensorflow
                {
                    // generate gradient subgraph for op.
                    var op = queue.Dequeue();
-                    if(op.name == "rnn/while/Exit")
-                    {
-
-                    }
+
                    _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
                    {
                        if (loop_state != null)
@@ -136,6 +133,7 @@ namespace Tensorflow
                            loop_state.ExitGradWhileContext(op, before: true);

                        Tensor[] in_grads = null;
+                        Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                        var is_partitioned_call = _IsPartitionedCall(op);
                        var is_func_call = false;
                        var has_out_grads = out_grads.Exists(x => x != null);
@@ -143,8 +141,6 @@ namespace Tensorflow
                        {
                            // A grad_fn must be defined, either as a function or as None
                            // for ops that do not have gradients.
-
-                            Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                            try
                            {
                                grad_fn = ops.get_gradient_function(op);
@@ -167,61 +163,57 @@ namespace Tensorflow
                                    throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
                                }
                            }
+                        }

-                            if (loop_state != null)
-                                loop_state.EnterGradWhileContext(op, before: false);
+                        if (loop_state != null)
+                            loop_state.EnterGradWhileContext(op, before: false);

-                            if ((is_func_call || grad_fn != null) && has_out_grads)
-                            {
-                                // NOTE: If _AggregatedGrads didn't compute a value for the i'th
-                                // output, it means that the cost does not depend on output[i],
-                                // therefore dC/doutput[i] is 0.
-                                foreach (var (i, out_grad) in enumerate(out_grads))
+                        if ((is_func_call || grad_fn != null) && has_out_grads)
+                        {
+                            // NOTE: If _AggregatedGrads didn't compute a value for the i'th
+                            // output, it means that the cost does not depend on output[i],
+                            // therefore dC/doutput[i] is 0.
+                            foreach (var (i, out_grad) in enumerate(out_grads))
                            {
-                                if (out_grad == null &&
-                                    (grad_fn == null || _IsTrainable(op.outputs[i])))
-                                {
-                                    // Only trainable outputs or outputs for a function call that
-                                    // will use SymbolicGradient get a zero gradient. Gradient
-                                    // functions should ignore the gradient for other outputs.
-                                    if (loop_state != null)
-                                        out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
-                                    else
-                                        out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
-                                }
-                            }
-
-                            tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
+                                if (out_grad == null &&
+                                    (grad_fn == null || _IsTrainable(op.outputs[i])))
                                {
-                                    if (grad_fn != null)
-                                    {
-                                        in_grads = _MaybeCompile(grad_scope,
-                                            op,
-                                            out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
-                                            null,
-                                            grad_fn);
-                                    }
+                                    // Only trainable outputs or outputs for a function call that
+                                    // will use SymbolicGradient get a zero gradient. Gradient
+                                    // functions should ignore the gradient for other outputs.
+                                    if (loop_state != null)
+                                        out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
                                    else
-                                    {
-                                        throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
-                                    }
-                                    _VerifyGeneratedGradients(in_grads, op);
-                                    if (gate_gradients && in_grads.Count(x => x != null) > 1)
-                                    {
-                                        ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
-                                        in_grads = control_flow_ops.tuple(in_grads);
-                                    }
-                                });
+                                        out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
+                                }
                            }
-                            else
+
+                            tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                            {
-                                // If no grad_fn is defined or none of out_grads is available,
-                                // just propagate a list of None backwards.
-                                in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
-                            }
+                                if (grad_fn != null)
+                                {
+                                    in_grads = _MaybeCompile(grad_scope,
+                                        op,
+                                        out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
+                                        null,
+                                        grad_fn);
+                                }
+                                else
+                                {
+                                    throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
+                                }
+                                _VerifyGeneratedGradients(in_grads, op);
+                                if (gate_gradients && in_grads.Count(x => x != null) > 1)
+                                {
+                                    ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
+                                    in_grads = control_flow_ops.tuple(in_grads);
+                                }
+                            });
                        }
                        else
                        {
+                            // If no grad_fn is defined or none of out_grads is available,
+                            // just propagate a list of None backwards.
                            in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                        }

@@ -370,7 +362,16 @@ namespace Tensorflow
                grads[op.name] = op_grads;
            }
            var t_grads = op_grads[t.value_index];
-            t_grads.Add(grad);
+            if (t_grads.Count == 0)
+                t_grads.Add(grad);
+            else
+                op_grads[t.value_index][0] = grad;
+
+            /*if (control_flow_util.IsLoopSwitch(op) &&
+                t_grads[0] == null)
+                op_grads[t.value_index] = new List<Tensor> { grad };
+            else
+                t_grads.Add(grad);*/
        }

        private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
@@ -379,7 +380,8 @@ namespace Tensorflow
                yield return op.inputs[i];
        }

-        private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
+        private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid,
+            ControlFlowState loop_state, int aggregation_method = 0)
        {
            var out_grads = _GetGrads(grads, op);

@@ -387,7 +389,10 @@ namespace Tensorflow
            {
                if (loop_state != null)
                {
-
+                    if (out_grads.Count > 1 &&
+                        out_grads[1].Count > 0 &&
+                        control_flow_util.IsLoopSwitch(op))
+                        continue;
                }
                // Aggregate multiple gradients, and convert [] to None.
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index e8eb216f..0f9ed2eb 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -182,10 +182,7 @@ namespace Tensorflow
            // This will be set by self.inputs.
            if (op_def == null)
                op_def = g.GetOpDef(node_def.Op);
-            if(node_def.Name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc")
-            {
-
-            }
+
            var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
            _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
            _is_stateful = op_def.IsStateful;