diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
index 02a5a573..fa7a77a6 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
@@ -446,6 +446,84 @@ namespace Tensorflow.Operations
return (total_iterations, next_n);
}
+ ///
+ /// Add an accumulation loop for every loop invariant.
+ ///
+ /// The Enter op for a loop invariant.
+ /// The partial gradient of an iteration for a loop invariant.
+ /// The gradient for a loop invariant.
+ public Tensor AddBackpropAccumulator(Operation op, Tensor grad)
+ {
+ Tensor acc = null;
+ Exit();
+ // Create a zeros tensor with the right shape for acc. If we don't
+ // know the full shape statically, we will have to get the shape
+ // dynamically from the forward inference. Getting the shape right
+ // for the zeros is only needed for the base case when the loop exits
+ // without running any iterations.
+ var shape = grad.TensorShape;
+ if (shape.is_fully_defined())
+ {
+ if (outer_context != null)
+ outer_context.Enter();
+ acc = constant_op.constant(0, grad.dtype, shape: shape, name: "b_acc");
+ if (outer_context != null)
+ outer_context.Exit();
+ }
+ else
+ {
+ var value = op.inputs[0];
+ if(outer_context is WhileContext wc)
+ {
+ // We are in a nested while loop.
+ var forward_ctxt = grad_state.forward_context;
+ forward_ctxt.outer_context.Enter();
+ var zeros_shape = array_ops.shape_internal(value, optimize: false);
+ forward_ctxt.outer_context.Exit();
+ var outer_grad_state = grad_state.outer_grad_state;
+ var history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape);
+ outer_context.Enter();
+ var real_shape = outer_grad_state.AddBackpropAccumulatedValue(
+ history_zeros_shape, zeros_shape);
+ acc = array_ops.zeros(real_shape, grad.dtype);
+ outer_context.Exit();
+ }
+ else
+ {
+ if (outer_context != null)
+ outer_context.Enter();
+ var zeros_shape = array_ops.shape_internal(value, optimize: false);
+ acc = array_ops.zeros(zeros_shape, grad.dtype);
+ if (outer_context != null)
+ outer_context.Exit();
+ }
+ throw new NotImplementedException("AddBackpropAccumulator");
+ }
+
+ Enter();
+ AddName(acc.name);
+ var enter_acc = _Enter(
+ acc,
+ _name,
+ is_constant: false,
+ parallel_iterations: _parallel_iterations,
+ name: "b_acc");
+ loop_enters.append(enter_acc);
+ var merge_acc = merge(new[] { enter_acc, enter_acc }, name: "b_acc")[0];
+
+ var switch_result = @switch(merge_acc, _pivot);
+ var (switch_acc_false, switch_acc_true) = (switch_result[0], switch_result[1]);
+
+ var add_acc = math_ops.add(switch_acc_true, grad);
+ var next_acc = _NextIteration(add_acc);
+ merge_acc.op._update_input(1, next_acc);
+
+ var result_acc = exit(switch_acc_false, name: "b_acc");
+ loop_exits.append(result_acc);
+ ExitResult(new[] { result_acc });
+ return result_acc;
+ }
+
///
/// Add the backprop loop that controls the iterations.
///