@@ -446,6 +446,84 @@ namespace Tensorflow.Operations
            return (total_iterations, next_n);
        }
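        // Backprop accumulator for loop invariants: this mirrors
        // WhileContext.AddBackpropAccumulator in TensorFlow's Python
        // control_flow_ops, split on whether grad's shape is static.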
        /// <summary>
        /// Add an accumulation loop for every loop invariant.
        /// </summary>
        /// <param name="op">The Enter op for a loop invariant.</param>
        /// <param name="grad">The partial gradient of an iteration for a loop invariant.</param>
        /// <returns>The gradient for a loop invariant.</returns>
        public Tensor AddBackpropAccumulator(Operation op, Tensor grad)
        {
            Tensor acc = null;
            // The accumulator's initial value must live outside this
            // (backprop) while context, so exit it before creating acc.
            Exit();
            // Create a zeros tensor with the right shape for acc. If we don't
            // know the full shape statically, we will have to get the shape
            // dynamically from the forward inference. Getting the shape right
            // for the zeros is only needed for the base case when the loop exits
            // without running any iterations.
            var shape = grad.TensorShape;
            if (shape.is_fully_defined())
            {
                if (outer_context != null)
                    outer_context.Enter();
                acc = constant_op.constant(0, grad.dtype, shape: shape, name: "b_acc");
                if (outer_context != null)
                    outer_context.Exit();
            }
            else
            {
                var value = op.inputs[0];
                // Take the nested-loop path only when the outer while context
                // carries a grad state; it is needed for outer_grad_state below.
                if (outer_context is WhileContext wc && wc.grad_state != null)
                {
                    // We are in a nested while loop.
                    // Compute the value's shape in the forward outer context,
                    // accumulate it there, and read it back in the backprop
                    // outer context to build a correctly shaped zeros tensor.
                    var forward_ctxt = grad_state.forward_context;
                    forward_ctxt.outer_context.Enter();
                    var zeros_shape = array_ops.shape_internal(value, optimize: false);
                    forward_ctxt.outer_context.Exit();
                    var outer_grad_state = grad_state.outer_grad_state;
                    var history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape);
                    outer_context.Enter();
                    var real_shape = outer_grad_state.AddBackpropAccumulatedValue(
                        history_zeros_shape, zeros_shape);
                    acc = array_ops.zeros(real_shape, grad.dtype);
                    outer_context.Exit();
                }
                else
                {
                    if (outer_context != null)
                        outer_context.Enter();
                    var zeros_shape = array_ops.shape_internal(value, optimize: false);
                    acc = array_ops.zeros(zeros_shape, grad.dtype);
                    if (outer_context != null)
                        outer_context.Exit();
                }
            }
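            // Wire the accumulation loop inside this grad context. Enter feeds
            // a Merge whose second input is a temporary placeholder (the same
            // tensor twice); Switch on the loop pivot routes the merged value
            // either around the loop (true branch: add grad, NextIteration) or
            // out of it (false branch: Exit). The placeholder Merge input is
            // patched to the NextIteration output via _update_input below.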
            Enter();
            AddName(acc.name);
            var enter_acc = _Enter(
                acc,
                _name,
                is_constant: false,
                parallel_iterations: _parallel_iterations,
                name: "b_acc");
            loop_enters.append(enter_acc);
            var merge_acc = merge(new[] { enter_acc, enter_acc }, name: "b_acc")[0];
            var switch_result = @switch(merge_acc, _pivot);
            var (switch_acc_false, switch_acc_true) = (switch_result[0], switch_result[1]);
            var add_acc = math_ops.add(switch_acc_true, grad);
            var next_acc = _NextIteration(add_acc);
            merge_acc.op._update_input(1, next_acc);
            var result_acc = exit(switch_acc_false, name: "b_acc");
            loop_exits.append(result_acc);
            ExitResult(new[] { result_acc });
            return result_acc;
        }
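        // In effect: for a loop invariant with per-iteration partial gradient
        // g_i, the subgraph built above computes, over N iterations,
        //
        //     acc_0 = zeros_like(grad)
        //     acc_i = acc_{i-1} + g_i   // Merge -> Switch(true) -> Add -> NextIteration
        //     result = acc_N            // Switch(false) -> Exit
        //
        // i.e. the gradient of a loop invariant is the sum of the partial
        // gradients contributed by each iteration.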
        /// <summary>
        /// Add the backprop loop that controls the iterations.
        /// </summary>