diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
index 02a5a573..fa7a77a6 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
@@ -446,6 +446,85 @@ namespace Tensorflow.Operations
             return (total_iterations, next_n);
         }
 
+        /// <summary>
+        /// Add an accumulation loop for every loop invariant.
+        /// </summary>
+        /// <param name="op">The Enter op for a loop invariant.</param>
+        /// <param name="grad">The partial gradient of an iteration for a loop invariant.</param>
+        /// <returns>The gradient for a loop invariant.</returns>
+        public Tensor AddBackpropAccumulator(Operation op, Tensor grad)
+        {
+            Tensor acc = null;
+            Exit();
+            // Create a zeros tensor with the right shape for acc. If we don't
+            // know the full shape statically, we will have to get the shape
+            // dynamically from the forward inference. Getting the shape right
+            // for the zeros is only needed for the base case when the loop exits
+            // without running any iterations.
+            var shape = grad.TensorShape;
+            if (shape.is_fully_defined())
+            {
+                if (outer_context != null)
+                    outer_context.Enter();
+                acc = constant_op.constant(0, grad.dtype, shape: shape, name: "b_acc");
+                if (outer_context != null)
+                    outer_context.Exit();
+            }
+            else
+            {
+                var value = op.inputs[0];
+                if (outer_context is WhileContext)
+                {
+                    // We are in a nested while loop.
+                    var forward_ctxt = grad_state.forward_context;
+                    forward_ctxt.outer_context.Enter();
+                    var zeros_shape = array_ops.shape_internal(value, optimize: false);
+                    forward_ctxt.outer_context.Exit();
+                    var outer_grad_state = grad_state.outer_grad_state;
+                    var history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape);
+                    outer_context.Enter();
+                    var real_shape = outer_grad_state.AddBackpropAccumulatedValue(
+                        history_zeros_shape, zeros_shape);
+                    acc = array_ops.zeros(real_shape, grad.dtype);
+                    outer_context.Exit();
+                }
+                else
+                {
+                    if (outer_context != null)
+                        outer_context.Enter();
+                    var zeros_shape = array_ops.shape_internal(value, optimize: false);
+                    acc = array_ops.zeros(zeros_shape, grad.dtype);
+                    if (outer_context != null)
+                        outer_context.Exit();
+                }
+            }
+
+            Enter();
+            AddName(acc.name);
+            var enter_acc = _Enter(
+                acc,
+                _name,
+                is_constant: false,
+                parallel_iterations: _parallel_iterations,
+                name: "b_acc");
+            loop_enters.append(enter_acc);
+            // Feed enter_acc into both Merge inputs for now; the second input
+            // is rewired to next_acc once the NextIteration op exists.
+            var merge_acc = merge(new[] { enter_acc, enter_acc }, name: "b_acc")[0];
+
+            var switch_result = @switch(merge_acc, _pivot);
+            var (switch_acc_false, switch_acc_true) = (switch_result[0], switch_result[1]);
+
+            var add_acc = math_ops.add(switch_acc_true, grad);
+            var next_acc = _NextIteration(add_acc);
+            merge_acc.op._update_input(1, next_acc);
+
+            var result_acc = exit(switch_acc_false, name: "b_acc");
+            loop_exits.append(result_acc);
+            ExitResult(new[] { result_acc });
+            return result_acc;
+        }
+
         /// <summary>
         /// Add the backprop loop that controls the iterations.
         /// </summary>
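
For reviewers unfamiliar with the TensorFlow control-flow ops, the method builds the standard gradient-accumulation subgraph for a loop invariant: the accumulator enters the backprop loop as zeros, a Merge/Switch pair routes it through the loop body where the per-iteration partial gradient is added, and the Exit op surfaces the final gradient. The plain C# sketch below is illustrative only and does not use the TF.NET graph API; `BackpropAccumulatorSketch`, `AccumulateLoopInvariantGrad`, and `partialGrads` are hypothetical names showing the value the constructed subgraph computes:

```csharp
using System;
using System.Collections.Generic;

public static class BackpropAccumulatorSketch
{
    // Hypothetical eager-mode equivalent of the subgraph built by
    // AddBackpropAccumulator. Each element of partialGrads stands for the
    // gradient contributed by one iteration of the backprop loop.
    public static double AccumulateLoopInvariantGrad(IEnumerable<double> partialGrads)
    {
        var acc = 0.0;                   // the zeros init ("b_acc" constant)
        foreach (var g in partialGrads)  // one trip per backprop iteration
            acc += g;                    // math_ops.add(switch_acc_true, grad)
        return acc;                      // the value surfaced by the Exit op
    }

    public static void Main()
    {
        // Three iterations, each contributing a partial gradient of 0.5,
        // accumulate to a total gradient of 1.5 for the loop invariant.
        Console.WriteLine(AccumulateLoopInvariantGrad(new[] { 0.5, 0.5, 0.5 }));
    }
}
```

The zero initialization matters for the base case called out in the code comment: if the loop runs zero iterations, Exit must still produce a correctly shaped all-zeros gradient, which is why the method builds the zeros in the outer context using either the static shape or a dynamically captured one.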