
WhileContext AddBackpropAccumulator

tags/v0.12
Oceania2018 · 6 years ago
commit 3e07372855
1 changed file with 78 additions and 0 deletions
  1. src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs  +78  -0

@@ -446,6 +446,84 @@ namespace Tensorflow.Operations
            return (total_iterations, next_n);
        }

        /// <summary>
        /// Add an accumulation loop for every loop invariant.
        /// </summary>
        /// <param name="op">The Enter op for a loop invariant.</param>
        /// <param name="grad">The partial gradient of an iteration for a loop invariant.</param>
        /// <returns>The gradient for a loop invariant.</returns>
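        /// <remarks>
        /// For a loop that runs n iterations, the accumulation subgraph built
        /// here computes: acc = zeros_like(grad); repeat n times: acc += grad_i;
        /// the Exit of acc is the invariant's total gradient.
        /// </remarks>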
        public Tensor AddBackpropAccumulator(Operation op, Tensor grad)
        {
            Tensor acc = null;
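            // Leave the gradient loop context so the accumulator's initial
            // value is created in the enclosing context.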
            Exit();
            // Create a zeros tensor with the right shape for acc. If we don't
            // know the full shape statically, we will have to get the shape
            // dynamically from the forward inference. Getting the shape right
            // for the zeros is only needed for the base case when the loop exits
            // without running any iterations.
            var shape = grad.TensorShape;
            if (shape.is_fully_defined())
            {
                if (outer_context != null)
                    outer_context.Enter();
                acc = constant_op.constant(0, grad.dtype, shape: shape, name: "b_acc");
                if (outer_context != null)
                    outer_context.Exit();
            }
            else
            {
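                // The shape is only known at run time: recover it from the
                // loop invariant's value in the forward graph.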
                var value = op.inputs[0];
                if (outer_context is WhileContext)
                {
                    // We are in a nested while loop.
                    var forward_ctxt = grad_state.forward_context;
                    forward_ctxt.outer_context.Enter();
                    var zeros_shape = array_ops.shape_internal(value, optimize: false);
                    forward_ctxt.outer_context.Exit();
                    var outer_grad_state = grad_state.outer_grad_state;
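                    // The shape tensor was computed inside the forward loop, so it
                    // must be stacked by a forward accumulator before the backprop
                    // loop can read it back per iteration.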
                    var history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape);
                    outer_context.Enter();
                    var real_shape = outer_grad_state.AddBackpropAccumulatedValue(
                        history_zeros_shape, zeros_shape);
                    acc = array_ops.zeros(real_shape, grad.dtype);
                    outer_context.Exit();
                }
                else
                {
                    if (outer_context != null)
                        outer_context.Enter();
                    var zeros_shape = array_ops.shape_internal(value, optimize: false);
                    acc = array_ops.zeros(zeros_shape, grad.dtype);
                    if (outer_context != null)
                        outer_context.Exit();
                }
            }

            Enter();
            AddName(acc.name);
            var enter_acc = _Enter(
                acc,
                _name,
                is_constant: false,
                parallel_iterations: _parallel_iterations,
                name: "b_acc");
            loop_enters.append(enter_acc);
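            // Merge the Enter op with itself for now; the second input is a
            // placeholder that is rewired to the back edge (next_acc) below.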
            var merge_acc = merge(new[] { enter_acc, enter_acc }, name: "b_acc")[0];

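            // The loop pivot routes the accumulator: the false output exits the
            // loop, the true output continues into the next iteration.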
            var switch_result = @switch(merge_acc, _pivot);
            var (switch_acc_false, switch_acc_true) = (switch_result[0], switch_result[1]);

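            // Add this iteration's partial gradient and feed the running sum
            // back through a NextIteration op.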
            var add_acc = math_ops.add(switch_acc_true, grad);
            var next_acc = _NextIteration(add_acc);
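            // Close the cycle: replace the placeholder second input of the merge
            // with the back edge from the next iteration.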
            merge_acc.op._update_input(1, next_acc);

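            // The false branch leaves the loop; its Exit is the invariant's gradient.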
            var result_acc = exit(switch_acc_false, name: "b_acc");
            loop_exits.append(result_acc);
            ExitResult(new[] { result_acc });
            return result_acc;
        }

        /// <summary>
        /// Add the backprop loop that controls the iterations.
        /// </summary>

