| @@ -446,6 +446,84 @@ namespace Tensorflow.Operations | |||
| return (total_iterations, next_n); | |||
| } | |||
| /// <summary> | |||
| /// Add an accumulation loop for every loop invariant. | |||
| /// </summary> | |||
| /// <param name="op">The Enter op for a loop invariant.</param> | |||
| /// <param name="grad">The partial gradient of an iteration for a loop invariant.</param> | |||
| /// <returns>The gradient for a loop invariant.</returns> | |||
| public Tensor AddBackpropAccumulator(Operation op, Tensor grad) | |||
| { | |||
| Tensor acc = null; | |||
| Exit(); | |||
| // Create a zeros tensor with the right shape for acc. If we don't | |||
| // know the full shape statically, we will have to get the shape | |||
| // dynamically from the forward inference. Getting the shape right | |||
| // for the zeros is only needed for the base case when the loop exits | |||
| // without running any iterations. | |||
| var shape = grad.TensorShape; | |||
| if (shape.is_fully_defined()) | |||
| { | |||
| if (outer_context != null) | |||
| outer_context.Enter(); | |||
| acc = constant_op.constant(0, grad.dtype, shape: shape, name: "b_acc"); | |||
| if (outer_context != null) | |||
| outer_context.Exit(); | |||
| } | |||
| else | |||
| { | |||
| var value = op.inputs[0]; | |||
| if(outer_context is WhileContext wc) | |||
| { | |||
| // We are in a nested while loop. | |||
| var forward_ctxt = grad_state.forward_context; | |||
| forward_ctxt.outer_context.Enter(); | |||
| var zeros_shape = array_ops.shape_internal(value, optimize: false); | |||
| forward_ctxt.outer_context.Exit(); | |||
| var outer_grad_state = grad_state.outer_grad_state; | |||
| var history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape); | |||
| outer_context.Enter(); | |||
| var real_shape = outer_grad_state.AddBackpropAccumulatedValue( | |||
| history_zeros_shape, zeros_shape); | |||
| acc = array_ops.zeros(real_shape, grad.dtype); | |||
| outer_context.Exit(); | |||
| } | |||
| else | |||
| { | |||
| if (outer_context != null) | |||
| outer_context.Enter(); | |||
| var zeros_shape = array_ops.shape_internal(value, optimize: false); | |||
| acc = array_ops.zeros(zeros_shape, grad.dtype); | |||
| if (outer_context != null) | |||
| outer_context.Exit(); | |||
| } | |||
| throw new NotImplementedException("AddBackpropAccumulator"); | |||
| } | |||
| Enter(); | |||
| AddName(acc.name); | |||
| var enter_acc = _Enter( | |||
| acc, | |||
| _name, | |||
| is_constant: false, | |||
| parallel_iterations: _parallel_iterations, | |||
| name: "b_acc"); | |||
| loop_enters.append(enter_acc); | |||
| var merge_acc = merge(new[] { enter_acc, enter_acc }, name: "b_acc")[0]; | |||
| var switch_result = @switch(merge_acc, _pivot); | |||
| var (switch_acc_false, switch_acc_true) = (switch_result[0], switch_result[1]); | |||
| var add_acc = math_ops.add(switch_acc_true, grad); | |||
| var next_acc = _NextIteration(add_acc); | |||
| merge_acc.op._update_input(1, next_acc); | |||
| var result_acc = exit(switch_acc_false, name: "b_acc"); | |||
| loop_exits.append(result_acc); | |||
| ExitResult(new[] { result_acc }); | |||
| return result_acc; | |||
| } | |||
| /// <summary> | |||
| /// Add the backprop loop that controls the iterations. | |||
| /// </summary> | |||