@@ -45,7 +45,19 @@ namespace Tensorflow.Gradients
             switch (op_ctxt)
             {
                 case WhileContext cwhile:
-                    throw new NotImplementedException("_SwitchGrad WhileContext");
+                    {
+                        var merge_grad = grad_ctxt.grad_state.switch_map.get(op);
+                        if (merge_grad != null)
+                            throw new NotImplementedException("_SwitchGrad merge_grad != null");
+                        else if (grads[0] != null)
+                        {
+                            merge_grad = merge(new[] { grads[0], grads[0] }, name: "b_switch")[0];
+                            grad_ctxt.grad_state.switch_map[op] = merge_grad;
+                            return new Tensor[] { merge_grad, null };
+                        }
+                        else
+                            return new Tensor[] { null, null };
+                    }
                 case CondContext ccond:
                     {
                         var zero_grad = grads[1 - op_ctxt.branch];
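The `merge_grad != null` branch above still throws. For context: on the first visit to a Switch (coming from the Exit branch), only `grads[0]` is available, so it is fed to both Merge inputs as a placeholder; on the second visit, the upstream Python `_SwitchGrad` patches the gradient arriving from the backward NextIteration into that Merge, closing the backward loop. A hedged C# sketch of the missing branch, assuming a port of the Python helper `_AddNextAndBackEdge` (not part of this diff):

```csharp
// Sketch only -- mirrors Python's _SwitchGrad second-visit branch.
// control_flow_ops._AddNextAndBackEdge is an assumed port, not in this PR.
if (merge_grad != null)
{
    // The gradient from the backward NextIteration arrives as grads[1];
    // wire it in as the second input of the existing backprop Merge.
    if (grads[1] != null)
        control_flow_ops._AddNextAndBackEdge(merge_grad, grads[1]);
    return new Tensor[] { null, null };
}
```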
@@ -74,7 +86,7 @@ namespace Tensorflow.Gradients
         /// <param name="inputs"></param>
         /// <param name="name"></param>
         /// <returns></returns>
-        internal static Tensor[] merge(Tensor[] inputs, string name = null)
+        internal static MergeOutput merge(Tensor[] inputs, string name = null)
         {
             return tf_with(ops.name_scope(name, "Merge", inputs), scope =>
             {
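`merge` now returns a `MergeOutput` instead of a bare `Tensor[]`, mirroring the `Merge(output, value_index)` namedtuple in Python. Call sites such as `merge(...)[0]` in `_SwitchGrad` keep compiling, which implies the type carries an indexer. An illustrative shape, assumed here (the actual definition lives elsewhere in the PR):

```csharp
// Illustrative only -- an assumed shape for MergeOutput, not the repo's source.
public class MergeOutput
{
    public Tensor output { get; }       // the merged tensor
    public Tensor value_index { get; }  // index of the input that was taken

    public MergeOutput(Tensor[] values)
    {
        output = values[0];
        value_index = values[1];
    }

    // merge(...)[0] yields the merged value, [1] the branch index,
    // matching the ordering of Python's Merge namedtuple.
    public Tensor this[int i] => i == 0 ? output : value_index;
}
```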
@@ -146,7 +158,7 @@ namespace Tensorflow.Gradients
         }
 
         [RegisterGradient("RefMerge")]
-        public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
+        public static Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
         {
             return _MergeGrad(op, grads);
         }
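This hunk (and the signature changes below) only adds `static`. The likely motivation: a registry that resolves gradient functions by their `RegisterGradient` attribute can bind a static method to a delegate without constructing `control_flow_grad`. A minimal illustration of the distinction, using the names shown in this diff:

```csharp
// A method group converts to an open delegate only for static methods;
// an instance method would need a `new control_flow_grad()` to bind to.
Func<Operation, Tensor[], Tensor[]> grad_fn = control_flow_grad._RefMergeGrad;
Tensor[] result = grad_fn(op, grads);  // the registry can invoke it instance-free
```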
@@ -155,43 +167,32 @@ namespace Tensorflow.Gradients
         /// Gradients for an exit op are calculated using an Enter op.
         /// </summary>
         [RegisterGradient("Exit")]
-        public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
+        public static Tensor[] _ExitGrad(Operation op, Tensor[] grads)
         {
-            throw new NotImplementedException("_ExitGrad");
-            // graph = ops.get_default_graph()
-            // # pylint: disable=protected-access
-            // op_ctxt = op._get_control_flow_context()
-            // grad_ctxt = graph._get_control_flow_context()
-            // # pylint: enable=protected-access
-            // if not grad_ctxt.back_prop:
-            //     # The flag `back_prop` is set by users to suppress gradient
-            //     # computation for this loop. If the attribute `back_prop` is false,
-            //     # no gradient computation.
-            //     return None
-            // if op_ctxt.grad_state:
-            //     raise TypeError("Second-order gradient for while loops not supported.")
-            // if isinstance(grad, ops.Tensor):
-            //     grad_ctxt.AddName(grad.name)
-            // else:
-            //     if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
-            //         raise TypeError("Type %s not supported" % type(grad))
-            //     grad_ctxt.AddName(grad.values.name)
-            //     grad_ctxt.AddName(grad.indices.name)
-            //     dense_shape = grad.dense_shape
-            //     if dense_shape is not None:
-            //         grad_ctxt.AddName(dense_shape.name)
-            // grad_ctxt.Enter()
-            // # pylint: disable=protected-access
-            // result = control_flow_ops._Enter(
-            //     grad, grad_ctxt.name, is_constant=False,
-            //     parallel_iterations=grad_ctxt.parallel_iterations,
-            //     name="b_exit")
-            // # pylint: enable=protected-access
-            // grad_ctxt.loop_enters.append(result)
-            // grad_ctxt.Exit()
-            // return result
+            var grad = grads[0];
+            var graph = ops.get_default_graph();
+            var op_ctxt = op._get_control_flow_context();
+            var grad_ctxt = graph._get_control_flow_context() as WhileContext;
+            // The flag `back_prop` is set by users to suppress gradient
+            // computation for this loop. If the attribute `back_prop` is false,
+            // no gradient computation.
+            if (!grad_ctxt.back_prop)
+                return null;
+            if (op_ctxt.grad_state != null)
+                throw new TypeError("Second-order gradient for while loops not supported.");
+            grad_ctxt.AddName(grad.name);
+            grad_ctxt.Enter();
+            var result = control_flow_ops._Enter(
+                grad, grad_ctxt.name, is_constant: false,
+                parallel_iterations: grad_ctxt.parallel_iterations,
+                name: "b_exit");
+            grad_ctxt.loop_enters.append(result);
+            grad_ctxt.Exit();
+            return new[] { result };
         }
 
         /// <summary>
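The rewritten `_ExitGrad` is where a loop's backward pass begins: the gradient flowing into a forward `Exit` is wrapped in an `Enter` named `"b_exit"` inside the gradient `WhileContext`. A hedged usage sketch of the path that exercises it; the exact TF.NET overloads of `tf.while_loop` and `tf.gradients` are assumptions and may differ by version:

```csharp
// Differentiating through a while loop drives the gradients in this file:
// forward Exit -> backward Enter ("b_exit"), NextIteration passes through,
// forward Enter -> backward Exit. API shapes below are assumed.
var x = tf.constant(1.0f);
var loop = tf.while_loop(
    cond: vals => vals[0] < 100.0f,          // loop while x < 100
    body: vals => new[] { vals[0] * 2.0f },  // x <- 2x
    loop_vars: new[] { x });
var dx = tf.gradients(loop[0], new[] { x }); // backward pass builds b_exit etc.
```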
@@ -200,15 +201,15 @@ namespace Tensorflow.Gradients
         /// Note that the backprop next_iteration is added in switch grad.
         /// </summary>
         [RegisterGradient("NextIteration")]
-        public Tensor[] _NextIterationGrad(object _, Tensor[] grad)
+        public static Tensor[] _NextIterationGrad(Operation op, Tensor[] grads)
         {
-            return grad;
+            return grads;
         }
 
         [RegisterGradient("RefNextIteration")]
-        public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad)
+        public static Tensor[] _RefNextIterationGrad(Operation op, Tensor[] grads)
         {
-            return grad;
+            return grads;
         }
 
         /// <summary>
@@ -218,33 +219,39 @@ namespace Tensorflow.Gradients
         /// For loop invariants, we need to add an accumulator loop.
         /// </summary>
         [RegisterGradient("Enter")]
-        public Tensor[] _EnterGrad(Tensor op, Tensor[] grad)
+        public static Tensor[] _EnterGrad(Operation op, Tensor[] grads)
         {
-            throw new NotImplementedException("_EnterGrad");
-            // graph = ops.get_default_graph()
-            // # pylint: disable=protected-access
-            // grad_ctxt = graph._get_control_flow_context()
-            // # pylint: enable=protected-access
-            // if not grad_ctxt.back_prop:
-            //     # Skip gradient computation, if the attribute `back_prop` is false.
-            //     return grad
-            // if grad_ctxt.grad_state is None:
-            //     # Pass the gradient through if we are not in a gradient while context.
-            //     return grad
-            // if op.get_attr("is_constant"):
-            //     # Add a gradient accumulator for each loop invariant.
-            //     if isinstance(grad, ops.Tensor):
-            //         result = grad_ctxt.AddBackpropAccumulator(op, grad)
-            //     elif isinstance(grad, ops.IndexedSlices):
-            //         result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
-            //     else:
-            //         # TODO(yuanbyu, lukasr): Add support for SparseTensor.
-            //         raise TypeError("Type %s not supported" % type(grad))
-            // else:
-            //     result = exit(grad)
-            // grad_ctxt.loop_exits.append(result)
-            // grad_ctxt.ExitResult([result])
-            // return result
+            Tensor result = null;
+            var grad = grads[0];
+            var graph = ops.get_default_graph();
+            var grad_ctxt = graph._get_control_flow_context() as WhileContext;
+            if (!grad_ctxt.back_prop)
+                // Skip gradient computation, if the attribute `back_prop` is false.
+                return grads;
+            if (grad_ctxt.grad_state == null)
+                // Pass the gradient through if we are not in a gradient while context.
+                return grads;
+            if (op.get_attr<bool>("is_constant"))
+            {
+                throw new NotImplementedException("_EnterGrad is_constant");
+                // Add a gradient accumulator for each loop invariant.
+                // if isinstance(grad, ops.Tensor):
+                //     result = grad_ctxt.AddBackpropAccumulator(op, grad)
+                // elif isinstance(grad, ops.IndexedSlices):
+                //     result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
+                // else:
+                //     # TODO(yuanbyu, lukasr): Add support for SparseTensor.
+                //     raise TypeError("Type %s not supported" % type(grad))
+            }
+            else
+            {
+                result = control_flow_ops.exit(grad);
+                grad_ctxt.loop_exits.append(result);
+                grad_ctxt.ExitResult(new[] { result });
+            }
+            return new Tensor[] { result };
         }
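The `is_constant` branch still throws: a loop invariant enters the loop once but contributes a gradient on every iteration, so its gradient must be summed across iterations rather than routed through a single backward `Exit`. A hedged sketch of what that branch would do, following the upstream Python `_EnterGrad`; it assumes `AddBackpropAccumulator` is ported onto the gradient `WhileContext`, which this PR does not include:

```csharp
// Sketch only -- mirrors Python's _EnterGrad for loop invariants.
// grad_ctxt.AddBackpropAccumulator is an assumed C# port, not in this diff.
if (op.get_attr<bool>("is_constant"))
{
    // Accumulate (sum) the per-iteration gradients of the invariant
    // instead of exiting the gradient loop with a single value.
    result = grad_ctxt.AddBackpropAccumulator(op, grad);
}
```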