diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
index 143aacb1..2552df8a 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
@@ -78,6 +78,26 @@ namespace Tensorflow.Operations.ControlFlows
///
public int pending_exits_count { get; set; }
+ Operation _grad_sync;
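+ /// A control trigger op used to synchronize the gradient loop;
+ /// stack pop ops are wired as control inputs to it so they execute in iteration order.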
+ public Operation grad_sync
+ {
+ get
+ {
+ if(_grad_sync == null)
+ {
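+ // Create the trigger outside of any inherited control dependencies.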
+ tf_with(ops.control_dependencies(null), delegate
+ {
+ _grad_sync = gen_control_flow_ops.control_trigger(name: "b_sync");
+ });
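+ // Attach the trigger to the gradient while-context and make the gradient
+ // loop index wait on it, so ops ordered before grad_sync (the stack pops)
+ // complete within each iteration.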
+ _grad_sync._set_control_flow_context(_grad_context);
+ _grad_index.op._add_control_input(_grad_sync);
+ if (_grad_context.outer_context != null)
+ _grad_context.outer_context.AddInnerOp(_grad_sync);
+ }
+ return _grad_sync;
+ }
+ }
+
public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
{
// Information needed by backprop.
@@ -155,7 +175,7 @@ namespace Tensorflow.Operations.ControlFlows
/// The stack that contains the accumulated history of the tensor.
public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
{
- using (_forward_index.graph.as_default())
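+ // Make the forward graph the default graph for the ops created below.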
+ _forward_index.graph.as_default();
{
var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
return tf_with(ops.control_dependencies(null), delegate
@@ -220,38 +240,33 @@ namespace Tensorflow.Operations.ControlFlows
public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false)
{
- throw new NotImplementedException();
- // history_ctxt = history_value.op._get_control_flow_context()
- // # Find the cond context that controls history_value if any.
- // cond_ctxt = None
- // value_ctxt = value.op._get_control_flow_context()
- // while value_ctxt and value_ctxt != history_ctxt:
- // if isinstance(value_ctxt, CondContext):
- // cond_ctxt = value_ctxt
- // break
- // value_ctxt = value_ctxt.outer_context
- // with ops.control_dependencies(None):
- // self.grad_context.Enter()
- // if cond_ctxt:
- // # Guard stack pop with a switch if it is controlled by a cond.
- // grad_state = self
- // pred = None
- // while pred is None and grad_state:
- // pred = grad_state.history_map.get(cond_ctxt.pred.name)
- // grad_state = grad_state.outer_grad_state
- // if pred is None:
- // pred = cond_ctxt.pred
- // branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
- // history_value = _SwitchRefOrTensor(history_value, pred)[branch]
- // pop = gen_data_flow_ops.stack_pop_v2(history_value,
- // value.dtype.base_dtype)
- // pop.set_shape(value.get_shape())
- // self.grad_context.Exit()
- // parallel_iterations = self.grad_context.parallel_iterations
- // if parallel_iterations > 1:
- // # All pops are ordered after pivot_for_body and before grad_sync.
- // self.grad_sync._add_control_input(pop.op)
- // return pop
+ var history_ctxt = history_value.op._get_control_flow_context();
+ // Find the cond context that controls history_value if any.
+ CondContext cond_ctxt = null;
+ Tensor pop = null;
+ var value_ctxt = value.op._get_control_flow_context();
+ while(value_ctxt != null && value_ctxt != history_ctxt)
+ {
+ if (value_ctxt is CondContext cc)
+ {
+ cond_ctxt = cc;
+ break;
+ }
+ value_ctxt = value_ctxt.outer_context;
+ }
+ tf_with(ops.control_dependencies(null), delegate
+ {
+ grad_context.Enter();
+ if(cond_ctxt != null)
+ {
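+ // Guarding the stack pop with a Switch when history_value is controlled
+ // by a cond (as done in the Python original) has not been ported yet.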
+ throw new NotImplementedException("AddBackpropAccumulatedValue");
+ }
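+ // Pop the value for this iteration from the stack that accumulated
+ // the tensor's forward history.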
+ pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype());
+ pop.set_shape(value.TensorShape);
+ grad_context.Exit();
+ });
+ var parallel_iterations = grad_context.parallel_iterations;
+ if (parallel_iterations > 1)
+ // All pops are ordered after pivot_for_body and before grad_sync.
+ grad_sync._add_control_input(pop.op);
+ return pop;
}
///
@@ -272,11 +287,28 @@ namespace Tensorflow.Operations.ControlFlows
var enter_op = util.GetLoopConstantEnter(cur_value);
if(enter_op != null)
{
- throw new NotImplementedException("GetRealValue");
+ // Special case: cur_value comes from a constant Enter node.
+ cur_value = enter_op.inputs[0];
+ cur_grad_state = cur_grad_state.outer_grad_state;
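+ // The value comes from the enclosing frame, so continue in the outer grad state.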
+ if(cur_grad_state == null)
+ {
+ // We are now outside all nested loops for this gradient(),
+ // so `value` is a loop invariant and there is no need to
+ // save the history of value. Just make cur_value enter
+ // the right control flow context.
+ real_value = _grad_context.AddValue(cur_value);
+ break;
+ }
}
else if (constant_op.is_constant(cur_value))
{
- throw new NotImplementedException("GetRealValue");
+ // The value to be forwarded is a constant: clone the constant in
+ // the gradient loop rather than saving and restoring its history
+ // with a stack.
+ real_value = constant_op.constant(
+ tensor_util.constant_value(cur_value), dtype: cur_value.dtype);
+ break;
}
else
{