diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
index 143aacb1..2552df8a 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
@@ -78,6 +78,26 @@ namespace Tensorflow.Operations.ControlFlows
         /// </summary>
         public int pending_exits_count { get; set; }
 
+        Operation _grad_sync;
+        public Operation grad_sync
+        {
+            get
+            {
+                if(_grad_sync == null)
+                {
+                    tf_with(ops.control_dependencies(null), delegate
+                    {
+                        _grad_sync = gen_control_flow_ops.control_trigger(name: "b_sync");
+                    });
+                    _grad_sync._set_control_flow_context(_grad_context);
+                    _grad_index.op._add_control_input(_grad_sync);
+                    if (_grad_context.outer_context != null)
+                        _grad_context.outer_context.AddInnerOp(_grad_sync);
+                }
+                return _grad_sync;
+            }
+        }
+
         public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
         {
             // Information needed by backprop.
@@ -155,7 +175,7 @@ namespace Tensorflow.Operations.ControlFlows
         /// <returns>The stack that contains the accumulated history of the tensor.</returns>
         public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
         {
-            using (_forward_index.graph.as_default())
+            _forward_index.graph.as_default();
             {
                 var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
                 return tf_with(ops.control_dependencies(null), delegate
@@ -220,38 +240,33 @@ namespace Tensorflow.Operations.ControlFlows
 
         public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false)
         {
-            throw new NotImplementedException();
-            //  history_ctxt = history_value.op._get_control_flow_context()
-            //  # Find the cond context that controls history_value if any.
-            //  cond_ctxt = None
-            //  value_ctxt = value.op._get_control_flow_context()
-            //  while value_ctxt and value_ctxt != history_ctxt:
-            //    if isinstance(value_ctxt, CondContext):
-            //      cond_ctxt = value_ctxt
-            //      break
-            //    value_ctxt = value_ctxt.outer_context
-            //  with ops.control_dependencies(None):
-            //    self.grad_context.Enter()
-            //    if cond_ctxt:
-            //      # Guard stack pop with a switch if it is controlled by a cond.
-            //      grad_state = self
-            //      pred = None
-            //      while pred is None and grad_state:
-            //        pred = grad_state.history_map.get(cond_ctxt.pred.name)
-            //        grad_state = grad_state.outer_grad_state
-            //      if pred is None:
-            //        pred = cond_ctxt.pred
-            //      branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
-            //      history_value = _SwitchRefOrTensor(history_value, pred)[branch]
-            //    pop = gen_data_flow_ops.stack_pop_v2(history_value,
-            //                                         value.dtype.base_dtype)
-            //    pop.set_shape(value.get_shape())
-            //  self.grad_context.Exit()
-            //  parallel_iterations = self.grad_context.parallel_iterations
-            //  if parallel_iterations > 1:
-            //    # All pops are ordered after pivot_for_body and before grad_sync.
-            //    self.grad_sync._add_control_input(pop.op)
-            //  return pop
+            var history_ctxt = history_value.op._get_control_flow_context();
+            // Find the cond context that controls history_value if any.
+            CondContext cond_ctxt = null;
+            Tensor pop = null;
+            var value_ctxt = value.op._get_control_flow_context();
+            while(value_ctxt != null && value_ctxt != history_ctxt)
+            {
+                if (value_ctxt is CondContext cc)
+                    cond_ctxt = cc;
+                value_ctxt = value_ctxt.outer_context;
+            }
+            tf_with(ops.control_dependencies(null), delegate
+            {
+                grad_context.Enter();
+                if(cond_ctxt != null)
+                {
+                    throw new NotImplementedException("AddBackpropAccumulatedValue");
+                }
+                pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype());
+                pop.set_shape(value.TensorShape);
+                grad_context.Exit();
+            });
+            var parallel_iterations = grad_context.parallel_iterations;
+            if (parallel_iterations > 1)
+                // All pops are ordered after pivot_for_body and before grad_sync.
+                grad_sync._add_control_input(pop.op);
+            return pop;
         }
 
         /// <summary>
@@ -272,11 +287,28 @@ namespace Tensorflow.Operations.ControlFlows
                 var enter_op = util.GetLoopConstantEnter(cur_value);
                 if(enter_op != null)
                 {
-                    throw new NotImplementedException("GetRealValue");
+                    // Special case: cur_value comes from a constant Enter node.
+                    cur_value = enter_op.inputs[0];
+                    cur_grad_state = cur_grad_state.outer_grad_state;
+                    if(cur_grad_state == null)
+                    {
+                        // We are now outside all nested loops for this gradient(),
+                        // so `value` is a loop invariant and there is no need to
+                        // save the history of value. Just make cur_value to enter
+                        // the right control flow context.
+                        real_value = _grad_context.AddValue(cur_value);
+                        break;
+                    }
                 }
                 else if (constant_op.is_constant(cur_value))
                 {
-                    throw new NotImplementedException("GetRealValue");
+                    // We are now outside all nested loops for this gradient(),
+                    // so `value` is a loop invariant and there is no need to
+                    // save the history of value. Just make cur_value to enter
+                    // the right control flow context.
+                    real_value = constant_op.constant(
+                        tensor_util.constant_value(cur_value), dtype: cur_value.dtype);
+                    break;
                 }
                 else
                 {