| @@ -78,6 +78,26 @@ namespace Tensorflow.Operations.ControlFlows | |||
| /// </summary> | |||
| public int pending_exits_count { get; set; } | |||
| Operation _grad_sync; | |||
| public Operation grad_sync | |||
| { | |||
| get | |||
| { | |||
| if(_grad_sync == null) | |||
| { | |||
| tf_with(ops.control_dependencies(null), delegate | |||
| { | |||
| _grad_sync = gen_control_flow_ops.control_trigger(name: "b_sync"); | |||
| }); | |||
| _grad_sync._set_control_flow_context(_grad_context); | |||
| _grad_index.op._add_control_input(_grad_sync); | |||
| if (_grad_context.outer_context != null) | |||
| _grad_context.outer_context.AddInnerOp(_grad_sync); | |||
| } | |||
| return _grad_sync; | |||
| } | |||
| } | |||
| public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_) | |||
| { | |||
| // Information needed by backprop. | |||
| @@ -155,7 +175,7 @@ namespace Tensorflow.Operations.ControlFlows | |||
| /// <returns>The stack that contains the accumulated history of the tensor.</returns> | |||
| public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false) | |||
| { | |||
| using (_forward_index.graph.as_default()) | |||
| _forward_index.graph.as_default(); | |||
| { | |||
| var curr_ctxt = ops.get_default_graph()._get_control_flow_context(); | |||
| return tf_with(ops.control_dependencies(null), delegate | |||
| @@ -220,38 +240,33 @@ namespace Tensorflow.Operations.ControlFlows | |||
| public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false) | |||
| { | |||
| throw new NotImplementedException(); | |||
| // history_ctxt = history_value.op._get_control_flow_context() | |||
| // # Find the cond context that controls history_value if any. | |||
| // cond_ctxt = None | |||
| // value_ctxt = value.op._get_control_flow_context() | |||
| // while value_ctxt and value_ctxt != history_ctxt: | |||
| // if isinstance(value_ctxt, CondContext): | |||
| // cond_ctxt = value_ctxt | |||
| // break | |||
| // value_ctxt = value_ctxt.outer_context | |||
| // with ops.control_dependencies(None): | |||
| // self.grad_context.Enter() | |||
| // if cond_ctxt: | |||
| // # Guard stack pop with a switch if it is controlled by a cond. | |||
| // grad_state = self | |||
| // pred = None | |||
| // while pred is None and grad_state: | |||
| // pred = grad_state.history_map.get(cond_ctxt.pred.name) | |||
| // grad_state = grad_state.outer_grad_state | |||
| // if pred is None: | |||
| // pred = cond_ctxt.pred | |||
| // branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch | |||
| // history_value = _SwitchRefOrTensor(history_value, pred)[branch] | |||
| // pop = gen_data_flow_ops.stack_pop_v2(history_value, | |||
| // value.dtype.base_dtype) | |||
| // pop.set_shape(value.get_shape()) | |||
| // self.grad_context.Exit() | |||
| // parallel_iterations = self.grad_context.parallel_iterations | |||
| // if parallel_iterations > 1: | |||
| // # All pops are ordered after pivot_for_body and before grad_sync. | |||
| // self.grad_sync._add_control_input(pop.op) | |||
| // return pop | |||
| var history_ctxt = history_value.op._get_control_flow_context(); | |||
| // Find the cond context that controls history_value if any. | |||
| CondContext cond_ctxt = null; | |||
| Tensor pop = null; | |||
| var value_ctxt = value.op._get_control_flow_context(); | |||
| while(value_ctxt != null && value_ctxt != history_ctxt) | |||
| { | |||
| if (value_ctxt is CondContext cc) | |||
| cond_ctxt = cc; | |||
| value_ctxt = value_ctxt.outer_context; | |||
| } | |||
| tf_with(ops.control_dependencies(null), delegate | |||
| { | |||
| grad_context.Enter(); | |||
| if(cond_ctxt != null) | |||
| { | |||
| throw new NotImplementedException("AddBackpropAccumulatedValue"); | |||
| } | |||
| pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype()); | |||
| pop.set_shape(value.TensorShape); | |||
| grad_context.Exit(); | |||
| }); | |||
| var parallel_iterations = grad_context.parallel_iterations; | |||
| if (parallel_iterations > 1) | |||
| // All pops are ordered after pivot_for_body and before grad_sync. | |||
| grad_sync._add_control_input(pop.op); | |||
| return pop; | |||
| } | |||
| /// <summary> | |||
| @@ -272,11 +287,28 @@ namespace Tensorflow.Operations.ControlFlows | |||
| var enter_op = util.GetLoopConstantEnter(cur_value); | |||
| if(enter_op != null) | |||
| { | |||
| throw new NotImplementedException("GetRealValue"); | |||
| // Special case: cur_value comes from a constant Enter node. | |||
| cur_value = enter_op.inputs[0]; | |||
| cur_grad_state = cur_grad_state.outer_grad_state; | |||
| if(cur_grad_state == null) | |||
| { | |||
| // We are now outside all nested loops for this gradient(), | |||
| // so `value` is a loop invariant and there is no need to | |||
| // save the history of value. Just make cur_value to enter | |||
| // the right control flow context. | |||
| real_value = _grad_context.AddValue(cur_value); | |||
| break; | |||
| } | |||
| } | |||
| else if (constant_op.is_constant(cur_value)) | |||
| { | |||
| throw new NotImplementedException("GetRealValue"); | |||
| // We are now outside all nested loops for this gradient(), | |||
| // so `value` is a loop invariant and there is no need to | |||
| // save the history of value. Just make cur_value to enter | |||
| // the right control flow context. | |||
| real_value = constant_op.constant( | |||
| tensor_util.constant_value(cur_value), dtype: cur_value.dtype); | |||
| break; | |||
| } | |||
| else | |||
| { | |||