
GradLoopState.AddBackpropAccumulatedValue

tags/v0.12
Oceania2018 6 years ago
commit 07ddb74738
1 changed file with 67 additions and 35 deletions:

  +67 −35  src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs

@@ -78,6 +78,26 @@ namespace Tensorflow.Operations.ControlFlows
         /// </summary>
         public int pending_exits_count { get; set; }
 
+        Operation _grad_sync;
+        public Operation grad_sync
+        {
+            get
+            {
+                if(_grad_sync == null)
+                {
+                    tf_with(ops.control_dependencies(null), delegate
+                    {
+                        _grad_sync = gen_control_flow_ops.control_trigger(name: "b_sync");
+                    });
+                    _grad_sync._set_control_flow_context(_grad_context);
+                    _grad_index.op._add_control_input(_grad_sync);
+                    if (_grad_context.outer_context != null)
+                        _grad_context.outer_context.AddInnerOp(_grad_sync);
+                }
+                return _grad_sync;
+            }
+        }
+
         public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
         {
             // Information needed by backprop.
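Note: the grad_sync property added above lazily creates a "b_sync" control-trigger op on first access, wires it into the gradient context, and caches it so later stack-pop ops can be ordered against it. A minimal sketch of that lazy, create-once-then-register pattern in plain C# (SyncNode and the registration callback are hypothetical stand-ins, not TensorFlow.NET types):

    using System;

    // Hypothetical stand-in for the synchronization op created by control_trigger.
    class SyncNode
    {
        public string Name { get; }
        public SyncNode(string name) => Name = name;
    }

    class GradStateSketch
    {
        SyncNode _gradSync;
        readonly Action<SyncNode> _registerWithContext; // stand-in for the context wiring above

        public GradStateSketch(Action<SyncNode> registerWithContext)
            => _registerWithContext = registerWithContext;

        // Created once on first access, registered with its surroundings, then cached.
        public SyncNode GradSync
        {
            get
            {
                if (_gradSync == null)
                {
                    _gradSync = new SyncNode("b_sync");
                    _registerWithContext?.Invoke(_gradSync);
                }
                return _gradSync;
            }
        }
    }

    class Demo
    {
        static void Main()
        {
            var state = new GradStateSketch(node => Console.WriteLine($"registered {node.Name}"));
            var a = state.GradSync;   // creates and registers "b_sync"
            var b = state.GradSync;   // cached; no second registration
            Console.WriteLine(ReferenceEquals(a, b));   // True
        }
    }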
@@ -155,7 +175,7 @@ namespace Tensorflow.Operations.ControlFlows
         /// <returns>The stack that contains the accumulated history of the tensor.</returns>
         public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
         {
-            using (_forward_index.graph.as_default())
+            _forward_index.graph.as_default();
             {
                 var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
                 return tf_with(ops.control_dependencies(null), delegate
@@ -220,38 +240,33 @@ namespace Tensorflow.Operations.ControlFlows
         public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false)
         {
-            throw new NotImplementedException();
-            // history_ctxt = history_value.op._get_control_flow_context()
-            // # Find the cond context that controls history_value if any.
-            // cond_ctxt = None
-            // value_ctxt = value.op._get_control_flow_context()
-            // while value_ctxt and value_ctxt != history_ctxt:
-            //     if isinstance(value_ctxt, CondContext):
-            //         cond_ctxt = value_ctxt
-            //         break
-            //     value_ctxt = value_ctxt.outer_context
-            // with ops.control_dependencies(None):
-            //     self.grad_context.Enter()
-            //     if cond_ctxt:
-            //         # Guard stack pop with a switch if it is controlled by a cond.
-            //         grad_state = self
-            //         pred = None
-            //         while pred is None and grad_state:
-            //             pred = grad_state.history_map.get(cond_ctxt.pred.name)
-            //             grad_state = grad_state.outer_grad_state
-            //         if pred is None:
-            //             pred = cond_ctxt.pred
-            //         branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
-            //         history_value = _SwitchRefOrTensor(history_value, pred)[branch]
-            //     pop = gen_data_flow_ops.stack_pop_v2(history_value,
-            //                                          value.dtype.base_dtype)
-            //     pop.set_shape(value.get_shape())
-            //     self.grad_context.Exit()
-            // parallel_iterations = self.grad_context.parallel_iterations
-            // if parallel_iterations > 1:
-            //     # All pops are ordered after pivot_for_body and before grad_sync.
-            //     self.grad_sync._add_control_input(pop.op)
-            // return pop
+            var history_ctxt = history_value.op._get_control_flow_context();
+            // Find the cond context that controls history_value if any.
+            CondContext cond_ctxt = null;
+            Tensor pop = null;
+            var value_ctxt = value.op._get_control_flow_context();
+            while(value_ctxt != null && value_ctxt != history_ctxt)
+            {
+                if (value_ctxt is CondContext cc)
+                    cond_ctxt = cc;
+                value_ctxt = value_ctxt.outer_context;
+            }
+            tf_with(ops.control_dependencies(null), delegate
+            {
+                grad_context.Enter();
+                if(cond_ctxt != null)
+                {
+                    throw new NotImplementedException("AddBackpropAccumulatedValue");
+                }
+                pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype());
+                pop.set_shape(value.TensorShape);
+                grad_context.Exit();
+            });
+            var parallel_iterations = grad_context.parallel_iterations;
+            if (parallel_iterations > 1)
+                // All pops are ordered after pivot_for_body and before grad_sync.
+                grad_sync._add_control_input(pop.op);
+            return pop;
         }
 
         /// <summary>
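Note: AddBackpropAccumulatedValue is the pop side of the history mechanism: AddForwardAccumulator pushes one value per forward-loop iteration onto a stack, and the gradient loop pops them back in reverse iteration order. A conceptual sketch in plain C#, with Stack<T> standing in for the stack_push_v2 / stack_pop_v2 graph ops (not the TensorFlow.NET API):

    using System;
    using System.Collections.Generic;

    class HistoryAccumulator<T>
    {
        readonly Stack<T> _history = new Stack<T>();

        // Forward pass: record the value produced in this iteration
        // (what AddForwardAccumulator's stack push does in the graph).
        public void AddForwardValue(T value) => _history.Push(value);

        // Backprop pass: retrieve the matching forward value
        // (what AddBackpropAccumulatedValue's stack pop does).
        public T AddBackpropAccumulatedValue() => _history.Pop();
    }

    class Demo
    {
        static void Main()
        {
            var acc = new HistoryAccumulator<double>();
            foreach (var x in new[] { 1.0, 2.0, 3.0 })   // forward iterations
                acc.AddForwardValue(x);

            // Gradient iterations consume the history in reverse order: 3, 2, 1.
            Console.WriteLine(acc.AddBackpropAccumulatedValue());
            Console.WriteLine(acc.AddBackpropAccumulatedValue());
            Console.WriteLine(acc.AddBackpropAccumulatedValue());
        }
    }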
@@ -272,11 +287,28 @@ namespace Tensorflow.Operations.ControlFlows
                 var enter_op = util.GetLoopConstantEnter(cur_value);
                 if(enter_op != null)
                 {
-                    throw new NotImplementedException("GetRealValue");
+                    // Special case: cur_value comes from a constant Enter node.
+                    cur_value = enter_op.inputs[0];
+                    cur_grad_state = cur_grad_state.outer_grad_state;
+                    if(cur_grad_state == null)
+                    {
+                        // We are now outside all nested loops for this gradient(),
+                        // so `value` is a loop invariant and there is no need to
+                        // save the history of value. Just make cur_value to enter
+                        // the right control flow context.
+                        real_value = _grad_context.AddValue(cur_value);
+                        break;
+                    }
                 }
                 else if (constant_op.is_constant(cur_value))
                 {
-                    throw new NotImplementedException("GetRealValue");
+                    // We are now outside all nested loops for this gradient(),
+                    // so `value` is a loop invariant and there is no need to
+                    // save the history of value. Just make cur_value to enter
+                    // the right control flow context.
+                    real_value = constant_op.constant(
+                        tensor_util.constant_value(cur_value), dtype: cur_value.dtype);
+                    break;
                 }
                 else
                 {
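Note: the GetRealValue hunk above fills in the loop-invariant shortcuts: when the value comes from a constant Enter node, the code peels one grad state per step until it is outside every nested loop, at which point no history needs to be saved. A small sketch of that outward walk in plain C# (GradLevel is a hypothetical stand-in for the outer_grad_state chain, not a TensorFlow.NET type):

    using System;

    class GradLevel
    {
        public string Name;
        public GradLevel Outer;   // outer_grad_state; null once outside all loops
    }

    class Demo
    {
        static void Main()
        {
            // Two nested loops: inner -> outer -> (null = outside all loops).
            var outer = new GradLevel { Name = "outer" };
            var inner = new GradLevel { Name = "inner", Outer = outer };

            // Mirror of the constant-Enter branch: peel one level per step until
            // there is no enclosing grad state, i.e. the value is loop invariant.
            var cur = inner;
            while (cur != null)
            {
                Console.WriteLine($"inside {cur.Name}");
                cur = cur.Outer;
            }
            Console.WriteLine("outside all nested loops: value is loop invariant");
        }
    }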

