@@ -78,6 +78,26 @@ namespace Tensorflow.Operations.ControlFlows
         /// </summary>
         public int pending_exits_count { get; set; }
+        Operation _grad_sync;
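+
+        /// <summary>
+        /// A control trigger op ("b_sync") used to order the stack pops in the
+        /// gradient loop: pops added by AddBackpropAccumulatedValue are made control
+        /// inputs of this op when parallel_iterations > 1, and the op itself is a
+        /// control input of _grad_index.op. Created lazily and placed in the
+        /// gradient control flow context.
+        /// </summary>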
+        public Operation grad_sync
+        {
+            get
+            {
+                if(_grad_sync == null)
+                {
+                    tf_with(ops.control_dependencies(null), delegate
+                    {
+                        _grad_sync = gen_control_flow_ops.control_trigger(name: "b_sync");
+                    });
+                    _grad_sync._set_control_flow_context(_grad_context);
+                    _grad_index.op._add_control_input(_grad_sync);
+                    if (_grad_context.outer_context != null)
+                        _grad_context.outer_context.AddInnerOp(_grad_sync);
+                }
+                return _grad_sync;
+            }
+        }
         public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
         {
             // Information needed by backprop.
@@ -155,7 +175,7 @@ namespace Tensorflow.Operations.ControlFlows
         /// <returns>The stack that contains the accumulated history of the tensor.</returns>
         public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
         {
-            using (_forward_index.graph.as_default())
+            _forward_index.graph.as_default();
             {
                 var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
                 return tf_with(ops.control_dependencies(null), delegate
@@ -220,38 +240,33 @@ namespace Tensorflow.Operations.ControlFlows
         public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false)
         {
-            throw new NotImplementedException();
-            // history_ctxt = history_value.op._get_control_flow_context()
-            // # Find the cond context that controls history_value if any.
-            // cond_ctxt = None
-            // value_ctxt = value.op._get_control_flow_context()
-            // while value_ctxt and value_ctxt != history_ctxt:
-            //   if isinstance(value_ctxt, CondContext):
-            //     cond_ctxt = value_ctxt
-            //     break
-            //   value_ctxt = value_ctxt.outer_context
-            // with ops.control_dependencies(None):
-            //   self.grad_context.Enter()
-            //   if cond_ctxt:
-            //     # Guard stack pop with a switch if it is controlled by a cond.
-            //     grad_state = self
-            //     pred = None
-            //     while pred is None and grad_state:
-            //       pred = grad_state.history_map.get(cond_ctxt.pred.name)
-            //       grad_state = grad_state.outer_grad_state
-            //     if pred is None:
-            //       pred = cond_ctxt.pred
-            //     branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
-            //     history_value = _SwitchRefOrTensor(history_value, pred)[branch]
-            //   pop = gen_data_flow_ops.stack_pop_v2(history_value,
-            //                                        value.dtype.base_dtype)
-            //   pop.set_shape(value.get_shape())
-            //   self.grad_context.Exit()
-            // parallel_iterations = self.grad_context.parallel_iterations
-            // if parallel_iterations > 1:
-            //   # All pops are ordered after pivot_for_body and before grad_sync.
-            //   self.grad_sync._add_control_input(pop.op)
-            // return pop
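+            // C# port of the Python reference implementation previously kept in comments.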
+            var history_ctxt = history_value.op._get_control_flow_context();
+            // Find the cond context that controls history_value if any.
+            CondContext cond_ctxt = null;
+            Tensor pop = null;
+            var value_ctxt = value.op._get_control_flow_context();
+            while(value_ctxt != null && value_ctxt != history_ctxt)
+            {
+                if (value_ctxt is CondContext cc)
+                {
+                    cond_ctxt = cc;
+                    break;
+                }
+                value_ctxt = value_ctxt.outer_context;
+            }
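+            // Pop the accumulated value from the forward stack inside the gradient
+            // context, with any surrounding control dependencies cleared.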
+            tf_with(ops.control_dependencies(null), delegate
+            {
+                grad_context.Enter();
+                if(cond_ctxt != null)
+                {
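+                    // Guarding the pop with a Switch when history_value is controlled
+                    // by a cond (see the Python reference above) is not ported yet.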
| throw new NotImplementedException("AddBackpropAccumulatedValue"); | |||||
| } | |||||
| pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype()); | |||||
| pop.set_shape(value.TensorShape); | |||||
| grad_context.Exit(); | |||||
| }); | |||||
| var parallel_iterations = grad_context.parallel_iterations; | |||||
| if (parallel_iterations > 1) | |||||
| // All pops are ordered after pivot_for_body and before grad_sync. | |||||
| grad_sync._add_control_input(pop.op); | |||||
| return pop; | |||||
         }
         /// <summary>
@@ -272,11 +287,28 @@ namespace Tensorflow.Operations.ControlFlows
                 var enter_op = util.GetLoopConstantEnter(cur_value);
                 if(enter_op != null)
                 {
-                    throw new NotImplementedException("GetRealValue");
+                    // Special case: cur_value comes from a constant Enter node.
+                    cur_value = enter_op.inputs[0];
+                    cur_grad_state = cur_grad_state.outer_grad_state;
+                    if(cur_grad_state == null)
+                    {
+                        // We are now outside all nested loops for this gradient(),
+                        // so `value` is a loop invariant and there is no need to
+                        // save the history of value. Just make cur_value enter
+                        // the right control flow context.
+                        real_value = _grad_context.AddValue(cur_value);
+                        break;
+                    }
                 }
                 else if (constant_op.is_constant(cur_value))
                 {
| throw new NotImplementedException("GetRealValue"); | |||||
| // We are now outside all nested loops for this gradient(), | |||||
| // so `value` is a loop invariant and there is no need to | |||||
| // save the history of value. Just make cur_value to enter | |||||
| // the right control flow context. | |||||
| real_value = constant_op.constant( | |||||
| tensor_util.constant_value(cur_value), dtype: cur_value.dtype); | |||||
| break; | |||||
                 }
                 else
                 {