| @@ -20,6 +20,7 @@ using System.Linq; | |||
| using Tensorflow.Operations.ControlFlows; | |||
| using static Tensorflow.ControlFlowContextDef; | |||
| using static Tensorflow.Binding; | |||
| using util = Tensorflow.control_flow_util; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| @@ -146,6 +147,14 @@ namespace Tensorflow.Operations | |||
| graph._set_control_flow_context(last_context); | |||
| } | |||
| /// <summary> | |||
| /// Make a list of tensors available in the outer context. | |||
| /// </summary> | |||
| public void ExitResult(Tensor[] result) | |||
| { | |||
| if(_outer_context != null) | |||
| { | |||
| throw new NotImplementedException("ExitResult"); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Add `op` to the current context. | |||
| /// </summary> | |||
| @@ -172,6 +181,11 @@ namespace Tensorflow.Operations | |||
| return null; | |||
| } | |||
| /// <summary> | |||
| /// Record that a tensor name is known to this context. | |||
| /// </summary> | |||
| public void AddName(string name) | |||
| { | |||
| _values.Add(name); | |||
| } | |||
| /// <summary> | |||
| /// Notifies a scope about an operator added to an inner scope. | |||
| /// </summary> | |||
| @@ -246,9 +260,11 @@ namespace Tensorflow.Operations | |||
| } | |||
| else | |||
| { | |||
| foreach(Tensor x in op.control_inputs) | |||
| foreach(Operation x in op.control_inputs) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| var ctxt = util.GetOutputContext(x); | |||
| if (ctxt != null && ctxt.GetWhileContext() == while_ctxt) | |||
| internal_control_inputs.append(x); | |||
| } | |||
| } | |||
| @@ -288,6 +304,14 @@ namespace Tensorflow.Operations | |||
| throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}"); | |||
| } | |||
| public virtual bool IsWhileContext() | |||
| { | |||
| throw new NotImplementedException("IsWhileContext"); | |||
| } | |||
| public virtual bool IsCondContext() | |||
| => false; | |||
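A note on the pattern here: the base context throws from `IsWhileContext()` (so a missing override fails loudly), while `IsCondContext()` simply defaults to `false`; the concrete `WhileContext` override appears near the end of this diff. A minimal, self-contained sketch of the same virtual type-check idea, with illustrative class names rather than the library's own:

```csharp
using System;

// Minimal sketch of the virtual type-check pattern used by the contexts in this
// diff; class names here are illustrative, not the library's own.
abstract class FlowContext
{
    // In the diff, the base IsWhileContext() throws so a missing override fails
    // loudly; here both default to false to keep the sketch runnable.
    public virtual bool IsWhileContext() => false;
    public virtual bool IsCondContext() => false;
}

class WhileLikeContext : FlowContext
{
    public override bool IsWhileContext() => true;
}

class CondLikeContext : FlowContext
{
    public override bool IsCondContext() => true;
}

static class ContextKindDemo
{
    static void Main()
    {
        FlowContext ctxt = new WhileLikeContext();
        Console.WriteLine(ctxt.IsWhileContext()); // True
        Console.WriteLine(ctxt.IsCondContext());  // False
    }
}
```

Callers can then ask a context what kind it is without down-casting, which keeps the cond/while routing in the gradient code simple.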
| public object to_proto() | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -14,6 +14,12 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Linq; | |||
| using System.Collections.Generic; | |||
| using util = Tensorflow.control_flow_util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Operations.ControlFlows | |||
| { | |||
| /// <summary> | |||
| @@ -21,6 +27,7 @@ namespace Tensorflow.Operations.ControlFlows | |||
| /// </summary> | |||
| public class ControlFlowState | |||
| { | |||
| Dictionary<ControlFlowContext, GradLoopState> _map; | |||
| //class ControlFlowState(object): | |||
| // """Maintain the mapping from the loops to their grad states.""" | |||
| @@ -40,51 +47,67 @@ namespace Tensorflow.Operations.ControlFlows | |||
| // return self._map.get(forward_ctxt) | |||
| // return None | |||
| // def ProcessUnusedLoopExits(self, pending_count, to_ops_set): | |||
| // """Process all the "unused" loop exits. | |||
| // The "unused" exits of the loops are added to `unused_exits`. An exit is | |||
| // unused if its pending_count is 0. If there is an exit with real gradient, | |||
| // all these deferred exits will enter the backprop loop with zero gradient. | |||
| // Otherwise, they will enter the backprop loop with None. As an example, | |||
| // people often write: | |||
| // ```python | |||
| // v1, _ = tf.while_loop(p, b, [x1, x2]) | |||
| // result = gradients(v1, x1) | |||
| // ``` | |||
| // The exit node for x2 is not included by the betweenness analysis. But we | |||
| // need to backprop x2 if x2 is involved in computing v1. | |||
| // Args: | |||
| // pending_count: The number of backprop inputs for every op. | |||
| // to_ops_set: The set of ops for ys in gradients(ys, xs) | |||
| // Returns: | |||
| // The set of unused loop exits that we know at this point we need | |||
| // to backprop. | |||
| // """ | |||
| // loop_exits = [] | |||
| // for grad_state in self._map.values(): | |||
| // for y in grad_state.forward_loop_exits: | |||
| // if pending_count[y.op] == 0: | |||
| // grad_state.pending_exits_count -= 1 | |||
| // if y.op not in to_ops_set: | |||
| // grad_state.unused_exits.append(y) | |||
| // if grad_state.pending_exits_count == 0: | |||
| // loop_exits.extend(grad_state.unused_exits) | |||
| // # Need to include Enters in backprop for higher-order gradients. | |||
| // for y in grad_state.forward_context.loop_enters: | |||
| // if pending_count[y.op] == 0: | |||
| // pending_count[y.op] = 1 | |||
| // return loop_exits | |||
| // def EnterGradWhileContext(self, op, before): | |||
| // """Enter the WhileContext for gradient computation.""" | |||
| // grad_state = self.GetGradState(op, before) | |||
| // if grad_state: | |||
| // grad_state.grad_context.Enter() | |||
| public ControlFlowState() | |||
| { | |||
| _map = new Dictionary<ControlFlowContext, GradLoopState>(); | |||
| } | |||
| /// <summary> | |||
| /// Return the grad state for this op if it's in a forward loop context. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| /// <param name="before"></param> | |||
| /// <returns></returns> | |||
| public GradLoopState GetGradState(Operation op, bool before) | |||
| { | |||
| ControlFlowContext forward_ctxt = null; | |||
| if (before && util.IsLoopExit(op)) | |||
| { | |||
| forward_ctxt = op._get_control_flow_context(); | |||
| forward_ctxt = forward_ctxt.outer_context; | |||
| if (forward_ctxt != null) | |||
| forward_ctxt = forward_ctxt.GetWhileContext(); | |||
| } | |||
| else | |||
| forward_ctxt = util.GetWhileContext(op); | |||
| if (forward_ctxt != null) | |||
| return _map.get(forward_ctxt); | |||
| return null; | |||
| } | |||
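`GetGradState` (and several methods below) lean on Python-style helpers such as `_map.get`, `append`, `extend`, `list`, and `len`, presumably supplied by the `Tensorflow.Binding` import at the top of the file. As a rough sketch of the dictionary behavior being assumed, the null-returning lookup could be written as:

```csharp
using System.Collections.Generic;

static class DictGetSketch
{
    // Python-style `dict.get(key)`: a missing key yields null rather than
    // throwing, which is the behavior `_map.get(forward_ctxt)` relies on above.
    public static TValue GetOrNull<TKey, TValue>(
        this Dictionary<TKey, TValue> map, TKey key) where TValue : class
        => map.TryGetValue(out: out var value, key: key) ? value : null;
}
```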
| public Tensor[] ProcessUnusedLoopExits(Dictionary<string, int> pending_count, List<Operation> to_ops_set) | |||
| { | |||
| var loop_exits = new List<Tensor>(); | |||
| foreach(var grad_state in _map.Values) | |||
| { | |||
| foreach(var y in grad_state.forward_loop_exits) | |||
| { | |||
| // Mirrors the Python check `pending_count[y.op] == 0`; a missing key counts as zero. | |||
| if(!pending_count.ContainsKey(y.op.name) || pending_count[y.op.name] == 0) | |||
| { | |||
| grad_state.pending_exits_count -= 1; | |||
| if (!to_ops_set.Contains(y.op)) | |||
| grad_state.unused_exits.append(y); | |||
| if (grad_state.pending_exits_count == 0) | |||
| loop_exits.extend(grad_state.unused_exits); | |||
| } | |||
| } | |||
| foreach(var y in grad_state.forward_context.loop_enters) | |||
| { | |||
| if (!pending_count.ContainsKey(y.op.name)) | |||
| pending_count[y.op.name] = 1; | |||
| } | |||
| } | |||
| return loop_exits.ToArray(); | |||
| } | |||
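The bookkeeping mirrors the Python docstring above: an exit whose pending count is zero was not reached by the betweenness analysis, and once every exit of a loop has been accounted for, all of its unused exits are released to the backprop loop together. A small stand-alone simulation of that counting with plain collections (names and values are illustrative only):

```csharp
using System;
using System.Collections.Generic;

static class UnusedExitsDemo
{
    static void Main()
    {
        // Two exits of one forward loop; only "exit_a" is needed by the ys.
        var forwardLoopExits = new[] { "exit_a", "exit_b" };
        var pendingCount = new Dictionary<string, int> { ["exit_a"] = 1 }; // "exit_b" missing => 0
        var toOps = new HashSet<string> { "exit_a" };

        int pendingExits = forwardLoopExits.Length;
        var unused = new List<string>();
        var released = new List<string>();

        foreach (var y in forwardLoopExits)
        {
            int count = pendingCount.TryGetValue(y, out var c) ? c : 0;
            if (count == 0)
            {
                pendingExits -= 1;                 // one fewer exit to wait for
                if (!toOps.Contains(y))
                    unused.Add(y);                 // exit not requested by gradients()
                if (pendingExits == 0)
                    released.AddRange(unused);     // release all unused exits at once
            }
        }

        // "exit_b" is unused but not yet released: "exit_a" is still pending.
        Console.WriteLine(string.Join(", ", unused)); // exit_b
        Console.WriteLine(released.Count);            // 0
    }
}
```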
| public void EnterGradWhileContext(Operation op, bool before) | |||
| { | |||
| var grad_state = GetGradState(op, before); | |||
| if (grad_state != null) | |||
| grad_state.grad_context.Enter(); | |||
| } | |||
| // def ExitGradWhileContext(self, op, before): | |||
| // """Exit the WhileContext for gradient computation.""" | |||
| @@ -118,6 +141,32 @@ namespace Tensorflow.Operations.ControlFlows | |||
| // if loop_exit.op not in between_ops: | |||
| // between_ops.add(loop_exit.op) | |||
| // between_op_list.append(loop_exit.op) | |||
| public void AddWhileContext(Operation op, List<Operation> between_op_list, List<Operation> between_ops) | |||
| { | |||
| var forward_ctxt = op.GetWhileContext(); | |||
| var grad_state = _map.ContainsKey(forward_ctxt) ? _map[forward_ctxt] : null; | |||
| if(grad_state == null) | |||
| { | |||
| GradLoopState outer_grad_state = null; | |||
| var outer_forward_ctxt = forward_ctxt.outer_context; | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt = outer_forward_ctxt.GetWhileContext(); | |||
| if (outer_forward_ctxt != null) | |||
| outer_grad_state = _map[outer_forward_ctxt]; | |||
| grad_state = new GradLoopState(forward_ctxt, outer_grad_state); | |||
| _map[forward_ctxt] = grad_state; | |||
| // We need to include all exits of a loop for backprop. | |||
| foreach (var loop_exit in grad_state.forward_loop_exits) | |||
| { | |||
| if(!between_ops.Contains(loop_exit.op)) | |||
| { | |||
| between_ops.add(loop_exit.op); | |||
| between_op_list.append(loop_exit.op); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // def ZerosLikeForExit(self, val): | |||
| // """Create zeros_like gradient for a loop exit. | |||
| @@ -174,70 +223,69 @@ namespace Tensorflow.Operations.ControlFlows | |||
| // result = array_ops.zeros_like(val, optimize=False) | |||
| // return result | |||
| // def ZerosLike(self, op, index): | |||
| // """Create zeros_like for the specified output of an op. | |||
| // If op is in a while loop that is part of gradients(), this method | |||
| // must be called in its grad loop context. | |||
| // Args: | |||
| // op: A tensorflow operation. | |||
| // index: the index for a specific output of the op. | |||
| // Returns: | |||
| // A zero tensor of the same shape of op.outputs[index]. | |||
| // """ | |||
| // if util.IsLoopSwitch(op): | |||
| // return None | |||
| // if op.graph._building_function: # pylint: disable=protected-access | |||
| // # The optimization here is tricky to apply to functions | |||
| // return array_ops.zeros_like(op.outputs[index]) | |||
| // dead_branch = util.IsSwitch(op) | |||
| // forward_ctxt = _GetWhileContext(op) | |||
| // grad_state = self._map.get(forward_ctxt) | |||
| // if grad_state is None: | |||
| // # op is not in a while loop that is part of gradients(). | |||
| // return ZerosLikeOutsideLoop(op, index) | |||
| // op_ctxt = op._get_control_flow_context() | |||
| // val = ops.convert_to_tensor(op.outputs[index], name="tensor") | |||
| // shape = val.get_shape() | |||
| // if shape.is_fully_defined(): | |||
| // # If the shape is known statically, just create a zero tensor with | |||
| // # the right shape in the grad loop context. | |||
| // result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype) | |||
| // if dead_branch: | |||
| // # op is a cond switch. Guard the zero tensor with a switch. | |||
| // pred = grad_state.history_map.get(op_ctxt.pred.name) | |||
| // branch = op_ctxt.branch | |||
| // result = _SwitchRefOrTensor(result, pred)[1 - branch] | |||
| // else: | |||
| // # Unknown shape so keep a history of the shape at runtime. | |||
| // if dead_branch: | |||
| // # Need to add a special switch to guard the value. | |||
| // pred = op_ctxt.pred | |||
| // branch = op_ctxt.branch | |||
| // op_ctxt.outer_context.Enter() | |||
| // val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch] | |||
| // zeros_shape = array_ops.shape_internal(val, optimize=False) | |||
| // op_ctxt.outer_context.Exit() | |||
| // val.op._set_control_flow_context(op_ctxt) | |||
| // zeros_shape.op._set_control_flow_context(op_ctxt) | |||
| // else: | |||
| // op_ctxt.Enter() | |||
| // zeros_shape = array_ops.shape_internal(val, optimize=False) | |||
| // op_ctxt.Exit() | |||
| // # Add forward accumulator for shape. | |||
| // grad_state.grad_context.Exit() | |||
| // history_zeros_shape = grad_state.AddForwardAccumulator( | |||
| // zeros_shape, dead_branch=dead_branch) | |||
| // grad_state.grad_context.Enter() | |||
| // # Create a zero tensor with the right shape. | |||
| // shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape, | |||
| // zeros_shape, dead_branch) | |||
| // result = array_ops.zeros(shape, val.dtype) | |||
| // return result | |||
| public Tensor ZerosLike(Operation op, int index) | |||
| { | |||
| if (util.IsLoopSwitch(op)) | |||
| return null; | |||
| if (op.graph.building_function) | |||
| return array_ops.zeros_like(op.outputs[index]); | |||
| var dead_branch = util.IsSwitch(op); | |||
| var forward_ctxt = util.GetWhileContext(op); | |||
| var grad_state = _map.get(forward_ctxt); | |||
| // op is not in a while loop that is part of gradients(). | |||
| if (grad_state == null) | |||
| return ZerosLikeOutsideLoop(op, index); | |||
| throw new NotImplementedException("ZerosLike"); | |||
| } | |||
| public Tensor ZerosLikeOutsideLoop(Operation op, int index) | |||
| { | |||
| var val = op.outputs[index]; | |||
| if (!util.IsSwitch(op)) | |||
| { | |||
| if (val.dtype == dtypes.resource) | |||
| throw new NotImplementedException("ZerosLikeOutsideLoop"); | |||
| /*return array_ops.zeros( | |||
| gen_resource_variable_ops.variable_shape(val), | |||
| dtype: default_gradient.get_zeros_dtype(val));*/ | |||
| return array_ops.zeros_like(val, optimize: false); | |||
| } | |||
| else | |||
| throw new NotImplementedException("ZerosLikeOutsideLoop"); | |||
| } | |||
| /// <summary> | |||
| /// Create zeros_like gradient for a loop exit. | |||
| /// </summary> | |||
| /// <param name="val"></param> | |||
| /// <returns></returns> | |||
| public Tensor ZerosLikeForExit(Tensor val) | |||
| { | |||
| Tensor result = null; | |||
| var val_shape = val.TensorShape; | |||
| var forward_ctxt = val.op._get_control_flow_context(); | |||
| var outer_forward_ctxt = forward_ctxt.outer_context; | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt = outer_forward_ctxt.GetWhileContext(); | |||
| GradLoopState outer_grad_state = null; | |||
| if (outer_forward_ctxt != null) | |||
| outer_grad_state = _map.get(outer_forward_ctxt); | |||
| // This is a nested loop. | |||
| if (outer_grad_state != null) | |||
| { | |||
| throw new NotImplementedException("ZerosLikeForExit"); | |||
| } | |||
| else | |||
| { | |||
| // If the shape is known statically, just create a zero tensor | |||
| // with the right shape. | |||
| if (val_shape.is_fully_defined()) | |||
| result = array_ops.zeros(val_shape.dims, val.dtype); | |||
| else | |||
| result = array_ops.zeros_like(val, optimize: false); | |||
| } | |||
| return result; | |||
| } | |||
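Both zeros helpers hinge on the same question: is the static shape fully defined? If so, a constant zero tensor is built directly from the known dimensions; otherwise the zeros have to be shaped at runtime (`zeros_like` with `optimize: false`, or a shape-history accumulator inside a gradient loop). A toy sketch of just that decision, with a nullable-int array standing in for `TensorShape`:

```csharp
using System;
using System.Linq;

static class ZerosDecisionDemo
{
    // Stand-in for TensorShape: a null dimension means "unknown at graph-build time".
    static bool IsFullyDefined(int?[] shape) => shape.All(d => d.HasValue);

    static string ZerosPlanForExit(int?[] exitShape)
        => IsFullyDefined(exitShape)
            ? "array_ops.zeros(static dims, dtype)"         // shape known: constant zeros
            : "array_ops.zeros_like(val, optimize: false)"; // shape unknown: build zeros at runtime

    static void Main()
    {
        Console.WriteLine(ZerosPlanForExit(new int?[] { 32, 10 }));   // static path
        Console.WriteLine(ZerosPlanForExit(new int?[] { null, 10 })); // dynamic path
    }
}
```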
| // def PostProcessing(self): | |||
| // """Perform postprocessing at the end of gradients(). | |||
| @@ -16,41 +16,16 @@ | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Operations.ControlFlows | |||
| { | |||
| /// <summary> | |||
| /// The state used for constructing the gradient graph for a while loop. | |||
| /// </summary> | |||
| public class GradLoopState | |||
| { | |||
| //class GradLoopState(object): | |||
| // """The state used for constructing the gradient graph for a while loop. | |||
| // We create a GradLoopState for each while loop in forward and its | |||
| // corresponding while loop in backprop. This gives us access to both | |||
| // the forward and the backprop WhileContexts. | |||
| // During the construction of gradient graph, any time when we detect | |||
| // a forward value that is needed for backprop, we create a history | |||
| // accumulator and add it to `history_map`. Any time when we backprop | |||
| // a loop switch op (in _SwitchGrad), we add the grad merge op in | |||
| // `switch_map`. | |||
| // """ | |||
| // def __init__(self, forward_ctxt, outer_grad_state): | |||
| // # The grad loop state for the outer while loop. | |||
| // self._outer_grad_state = None | |||
| // # The while loop context for forward. | |||
| // self._forward_context = None | |||
| // # The loop counter added by AddForwardLoopCounter. It is the value | |||
| // # of the loop counter for the next iteration. | |||
| // self._forward_index = None | |||
| // # A sync op for forward. | |||
| // self._forward_sync = None | |||
| // # The while loop context for backprop. | |||
| private WhileContext _grad_context = null; | |||
| public WhileContext grad_context => _grad_context; | |||
| @@ -65,156 +40,91 @@ namespace Tensorflow.Operations.ControlFlows | |||
| // # Information needed by backprop. | |||
| private Hashtable _history_map = new Hashtable(); | |||
| public Hashtable history_map => _history_map; | |||
| private Hashtable _switch_map = new Hashtable(); | |||
| public Hashtable switch_map => _switch_map; | |||
| // self._unused_exits = [] | |||
| // self._deferred_exits = [] | |||
| // self._forward_loop_exits = list(forward_ctxt.loop_exits) | |||
| // self._pending_exits_count = len(forward_ctxt.loop_exits) | |||
| // self._outer_grad_state = outer_grad_state | |||
| // if outer_grad_state: | |||
| // outer_forward_ctxt = outer_grad_state.forward_context | |||
| // else: | |||
| // if not hasattr(forward_ctxt, "outer_context"): | |||
| // raise ValueError("Failed to call gradients on a while loop without" | |||
| // "properly serializing graph via MetaGraphDef") | |||
| // outer_forward_ctxt = forward_ctxt.outer_context | |||
| // # Add the forward loop counter. | |||
| // with forward_ctxt._graph.as_default(): # pylint: disable=protected-access | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Enter() | |||
| // cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Exit() | |||
| // self._forward_context = forward_ctxt | |||
| // self._forward_index = forward_index | |||
| // # Add the backprop WhileContext, and the backprop loop counter. | |||
| // if outer_grad_state: | |||
| // # This is a nested loop. Remember the iteration counts for each | |||
| // # execution of this inner loop. | |||
| // outer_forward_ctxt.AddName(cnt.name) | |||
| // history_cnt = outer_grad_state.AddForwardAccumulator(cnt) | |||
| // outer_grad_ctxt = outer_grad_state.grad_context | |||
| // outer_grad_ctxt.Enter() | |||
| // self._grad_context = WhileContext( | |||
| // maximum_iterations=forward_ctxt.maximum_iterations, | |||
| // parallel_iterations=forward_ctxt.parallel_iterations, | |||
| // back_prop=forward_ctxt.back_prop, | |||
| // swap_memory=forward_ctxt.swap_memory, | |||
| // name=forward_ctxt.name, | |||
| // grad_state=self) | |||
| // real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt) | |||
| // self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||
| // real_cnt, outer_grad_state) | |||
| // outer_grad_ctxt.Exit() | |||
| // else: | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Enter() | |||
| // self._grad_context = WhileContext( | |||
| // maximum_iterations=forward_ctxt.maximum_iterations, | |||
| // parallel_iterations=forward_ctxt.parallel_iterations, | |||
| // back_prop=forward_ctxt.back_prop, | |||
| // swap_memory=forward_ctxt.swap_memory, | |||
| // name=forward_ctxt.name, | |||
| // grad_state=self) | |||
| // self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||
| // cnt, outer_grad_state) | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Exit() | |||
| // @property | |||
| // def outer_grad_state(self): | |||
| // """The grad loop state for outer loop.""" | |||
| // return self._outer_grad_state | |||
| // @property | |||
| // def forward_context(self): | |||
| // """The while loop context for forward.""" | |||
| // return self._forward_context | |||
| // @property | |||
| // def forward_index(self): | |||
| // """The loop index of forward loop.""" | |||
| // return self._forward_index | |||
| // @property | |||
| // def forward_sync(self): | |||
| // """A control trigger node for synchronization in the forward loop. | |||
| // One main use is to keep the push ops of a stack executed in the | |||
| // iteration order. | |||
| // """ | |||
| // if self._forward_sync is None: | |||
| // with ops.control_dependencies(None): | |||
| // self._forward_sync = control_trigger(name="f_sync") | |||
| // self._forward_sync._set_control_flow_context(self._forward_context) | |||
| // self._forward_index.op._add_control_input(self._forward_sync) | |||
| // return self._forward_sync | |||
| // @property | |||
| // def grad_context(self): | |||
| // """The corresponding WhileContext for gradient.""" | |||
| // return self._grad_context | |||
| // @property | |||
| // def grad_index(self): | |||
| // """The loop index of backprop loop.""" | |||
| // return self._grad_index | |||
| // @property | |||
| // def grad_sync(self): | |||
| // """A control trigger node for synchronization in the grad loop. | |||
| // One main use is to keep the pop ops of a stack executed in the | |||
| // iteration order. | |||
| // """ | |||
| // if self._grad_sync is None: | |||
| // with ops.control_dependencies(None): | |||
| // self._grad_sync = control_trigger(name="b_sync") | |||
| // self._grad_sync._set_control_flow_context(self._grad_context) | |||
| // self._grad_index.op._add_control_input(self._grad_sync) | |||
| // if self._grad_context.outer_context: | |||
| // self._grad_context.outer_context.AddInnerOp(self._grad_sync) | |||
| // return self._grad_sync | |||
| // @property | |||
| // def history_map(self): | |||
| // """The map that records all the tensors needed for backprop.""" | |||
| // return self._history_map | |||
| // @property | |||
| // def switch_map(self): | |||
| // """The map that records all the Switch ops for the while loop.""" | |||
| // return self._switch_map | |||
| // @property | |||
| // def unused_exits(self): | |||
| // """The list of "unused" exits.""" | |||
| // return self._unused_exits | |||
| // @property | |||
| // def deferred_exits(self): | |||
| // """The list of "deferred" exits.""" | |||
| // return self._deferred_exits | |||
| // @property | |||
| // def forward_loop_exits(self): | |||
| // """The list of exits of the forward loop.""" | |||
| // return self._forward_loop_exits | |||
| // @property | |||
| // def pending_exits_count(self): | |||
| // """The number of exits we expect to see but haven't.""" | |||
| // return self._pending_exits_count | |||
| // @pending_exits_count.setter | |||
| // def pending_exits_count(self, cnt): | |||
| // """Set the pending count to cnt.""" | |||
| // self._pending_exits_count = cnt | |||
| Dictionary<Operation, Tensor> _switch_map = new Dictionary<Operation, Tensor>(); | |||
| public Dictionary<Operation, Tensor> switch_map => _switch_map; | |||
| /// <summary> | |||
| /// The while loop context for forward. | |||
| /// </summary> | |||
| WhileContext _forward_context; | |||
| public WhileContext forward_context => _forward_context; | |||
| /// <summary> | |||
| /// The grad loop state for the outer while loop. | |||
| /// </summary> | |||
| GradLoopState _outer_grad_state; | |||
| public GradLoopState outer_grad_state => _outer_grad_state; | |||
| Tensor _forward_index; | |||
| Tensor _grad_index; | |||
| Tensor[] _forward_loop_exits; | |||
| /// <summary> | |||
| /// The list of exits of the forward loop. | |||
| /// </summary> | |||
| public Tensor[] forward_loop_exits => _forward_loop_exits; | |||
| List<Tensor> _deferred_exits; | |||
| public List<Tensor> deferred_exits => _deferred_exits; | |||
| List<Tensor> _unused_exits; | |||
| public List<Tensor> unused_exits => _unused_exits; | |||
| /// <summary> | |||
| /// The number of exits we expect to see but haven't. | |||
| /// </summary> | |||
| public int pending_exits_count { get; set; } | |||
| public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_) | |||
| { | |||
| // Information needed by backprop. | |||
| _unused_exits = new List<Tensor>(); | |||
| _deferred_exits = new List<Tensor>(); | |||
| _forward_loop_exits = list(forward_ctxt.loop_exits); | |||
| pending_exits_count = len(forward_ctxt.loop_exits); | |||
| _outer_grad_state = outer_grad_state_; | |||
| ControlFlowContext outer_forward_ctxt = null; | |||
| if (outer_grad_state_ != null) | |||
| outer_forward_ctxt = outer_grad_state_.forward_context; | |||
| // Add the forward loop counter. | |||
| // with forward_ctxt._graph.as_default(): | |||
| Tensor cnt, forward_index; | |||
| { | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt.Enter(); | |||
| (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state); | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt.Exit(); | |||
| } | |||
| _forward_context = forward_ctxt; | |||
| _forward_index = forward_index; | |||
| // Add the backprop WhileContext, and the backprop loop counter. | |||
| if (outer_grad_state != null) | |||
| { | |||
| // This is a nested loop. Remember the iteration counts for each | |||
| // execution of this inner loop. | |||
| throw new NotImplementedException("GradLoopState"); | |||
| } | |||
| else | |||
| { | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt.Enter(); | |||
| _grad_context = new WhileContext( | |||
| maximum_iterations: forward_ctxt.maximum_iterations, | |||
| parallel_iterations: forward_ctxt.parallel_iterations, | |||
| back_prop: forward_ctxt.back_prop, | |||
| swap_memory: forward_ctxt.swap_memory, | |||
| name: forward_ctxt.name, | |||
| grad_state: this); | |||
| _grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state); | |||
| if (outer_forward_ctxt != null) | |||
| outer_forward_ctxt.Exit(); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Add an accumulator for each forward tensor that is needed in backprop. | |||
| @@ -0,0 +1,36 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| /// <summary> | |||
| /// Holds the two outputs of a Merge op: the merged value and the index of the chosen input. | |||
| /// </summary> | |||
| public class MergeOutput | |||
| { | |||
| Tensor output; | |||
| Tensor value_index; | |||
| public MergeOutput(Tensor[] values) | |||
| { | |||
| output = values[0]; | |||
| value_index = values[1]; | |||
| } | |||
| public Tensor this[int idx] | |||
| { | |||
| get | |||
| { | |||
| switch(idx) | |||
| { | |||
| case 0: | |||
| return output; | |||
| case 1: | |||
| return value_index; | |||
| default: | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| public static implicit operator Tensor(MergeOutput merge) | |||
| => merge.output; | |||
| } | |||
| } | |||
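`MergeOutput` packages the two outputs of a Merge node, the merged value and the index of the input that produced it, while converting implicitly to the merged tensor, which is why call sites later in this diff can still write `.Select(m => (Tensor)m)`. A generic, runnable analogue of the same wrapper shape (the real class holds `Tensor`s, which cannot be constructed outside a graph):

```csharp
using System;

// Illustrative analogue of MergeOutput: index 0 is the payload, index 1 the
// selector, and the implicit conversion yields the payload.
class PairOutput<T> where T : class
{
    readonly T output;
    readonly T valueIndex;

    public PairOutput(T[] values) { output = values[0]; valueIndex = values[1]; }

    public T this[int idx] => idx == 0 ? output : idx == 1 ? valueIndex : null;

    public static implicit operator T(PairOutput<T> p) => p.output;
}

static class PairOutputDemo
{
    static void Main()
    {
        var m = new PairOutput<string>(new[] { "merged_value", "chosen_index" });
        string asPayload = m;            // implicit conversion, like (Tensor)m
        Console.WriteLine(asPayload);    // merged_value
        Console.WriteLine(m[1]);         // chosen_index
    }
}
```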
| @@ -32,12 +32,17 @@ namespace Tensorflow.Operations | |||
| bool _back_prop=true; | |||
| GradLoopState _grad_state =null; | |||
| Tensor _maximum_iterations; | |||
| public Tensor maximum_iterations => _maximum_iterations; | |||
| int _parallel_iterations; | |||
| public int parallel_iterations => _parallel_iterations; | |||
| bool _swap_memory; | |||
| public bool swap_memory => _swap_memory; | |||
| Tensor _pivot_for_pred; | |||
| Tensor _pivot_for_body; | |||
| List<Tensor> _loop_exits; | |||
| public List<Tensor> loop_exits => _loop_exits; | |||
| List<Tensor> _loop_enters; | |||
| public List<Tensor> loop_enters => _loop_enters; | |||
| Graph _graph; | |||
| public override GradLoopState grad_state => _grad_state; | |||
| public override bool back_prop => _back_prop; | |||
| @@ -109,7 +114,7 @@ namespace Tensorflow.Operations | |||
| /// <summary> | |||
| /// Add the loop termination condition and body to the graph. | |||
| /// </summary> | |||
| internal Tensor[] BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred, | |||
| internal LoopVar<TItem> BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred, | |||
| Func<LoopVar<TItem>, LoopVar<TItem>> body, | |||
| LoopVar<TItem> loop_vars, | |||
| TensorShape[] shape_invariants, | |||
| @@ -132,14 +137,16 @@ namespace Tensorflow.Operations | |||
| pred, body, original_loop_vars, loop_vars_tensors, shape_invariants); | |||
| Exit(); | |||
| var flat_result = original_body_result; | |||
| var flat_result = nest.flatten2(original_body_result) | |||
| .Select(x => x as ITensorOrTensorArray) | |||
| .ToArray(); | |||
| var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars); | |||
| var packed_exit_vars = nest.pack_sequence_as( | |||
| var packed_exit_vars = nest.pack_sequence_as2( | |||
| structure: original_body_result, | |||
| flat_sequence: exit_vars_with_tensor_arrays); | |||
| return packed_exit_vars as Tensor[]; | |||
| return packed_exit_vars; | |||
| } | |||
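The tail of `BuildLoop` follows the usual `nest` choreography: flatten the structured `LoopVar<TItem>` to a flat list, swap any `TensorArray` leaves for their flow tensors and back, then pack the flat exit values into the original structure with `pack_sequence_as2` so the caller receives the same shape of object it passed in. A toy flatten/pack pair over plain nested lists, only to illustrate the round trip (these helpers are hypothetical stand-ins, not `nest.flatten2`/`pack_sequence_as2` themselves):

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

static class NestSketch
{
    // Flatten a structure of nested sequences down to its leaves, in order.
    static List<object> Flatten(object structure)
    {
        if (structure is IEnumerable seq && !(structure is string))
            return seq.Cast<object>().SelectMany(item => Flatten(item)).ToList();
        return new List<object> { structure };
    }

    // Rebuild the original nesting, consuming leaves from `flat` in order.
    static object Pack(object structure, Queue<object> flat)
    {
        if (structure is IEnumerable seq && !(structure is string))
            return seq.Cast<object>().Select(item => Pack(item, flat)).ToList();
        return flat.Dequeue();
    }

    static void Main()
    {
        var structure = new List<object> { 1, new List<object> { 2, 3 } };
        // Transform the flat leaves (here: multiply by 10), then pack them back.
        var leaves = Flatten(structure).Select(x => (object)((int)x * 10));
        var packed = Pack(structure, new Queue<object>(leaves));
        // `packed` has the same nesting as `structure`, with transformed leaves.
        Console.WriteLine(string.Join(",", Flatten(packed))); // 10,20,30
    }
}
```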
| private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array) | |||
| @@ -167,7 +174,7 @@ namespace Tensorflow.Operations | |||
| /// <param name="loop_vars"></param> | |||
| /// <param name="shape_invariants"></param> | |||
| /// <returns></returns> | |||
| private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred, | |||
| private (LoopVar<TItem>, Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred, | |||
| Func<LoopVar<TItem>, LoopVar<TItem>> body, | |||
| LoopVar<TItem> original_loop_vars, | |||
| Tensor[] loop_vars, | |||
| @@ -221,6 +228,7 @@ namespace Tensorflow.Operations | |||
| var merge_vars = enter_vars | |||
| .Select(x => merge(new[] { x, x })) | |||
| .Select(m => (Tensor)m) | |||
| .ToArray(); | |||
| _pivot_for_pred = merge_vars[0]; | |||
| @@ -250,13 +258,15 @@ namespace Tensorflow.Operations | |||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| // Store body_result to keep track of TensorArrays returned by body | |||
| var original_body_result = new[] { body_result }; | |||
| var original_body_result = body_result; | |||
| // Convert TensorArrays returned by body into their flow variables | |||
| var result = new[] { body_result }; | |||
| var result = nest.flatten2(body_result) | |||
| .Select(x => _convert_tensorarray_to_flow(x)) | |||
| .ToArray(); | |||
| // result = ops.convert_n_to_tensor_or_composite(result); | |||
| var next_vars = new List<Tensor>(); | |||
| //foreach (var (m, v) in zip(merge_vars, result)) | |||
| //next_vars.Add(_AddNextAndBackEdge(m, v)); | |||
| foreach (var (m, v) in zip(merge_vars, result)) | |||
| next_vars.Add(_AddNextAndBackEdge(m, v)); | |||
| // Add the exit ops. | |||
| var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); | |||
| @@ -264,7 +274,7 @@ namespace Tensorflow.Operations | |||
| // Exit the loop. | |||
| // ExitResult(exit_vars); | |||
| return (null, exit_vars.ToArray()); | |||
| return (original_body_result, exit_vars.ToArray()); | |||
| } | |||
| private void _FixControlInputsAndContext(Tensor[] enters) | |||
| @@ -282,7 +292,18 @@ namespace Tensorflow.Operations | |||
| var keep_as_control_input = true; | |||
| var op_ctxt = control_flow_util.GetOutputContext(op); | |||
| var outer_ctxt = outer_context; | |||
| throw new NotImplementedException(""); | |||
| var outer_while_context = outer_ctxt == null ? null : outer_ctxt.GetWhileContext(); | |||
| while (outer_ctxt != op_ctxt) | |||
| { | |||
| if (outer_ctxt == null || outer_ctxt == outer_while_context) | |||
| { | |||
| keep_as_control_input = false; | |||
| break; | |||
| } | |||
| outer_ctxt = outer_ctxt.outer_context; | |||
| } | |||
| if (keep_as_control_input) | |||
| outer_control_inputs.append(op); | |||
| } | |||
| // op for op in control_inputs if self._IsInOuterContext(op) | |||
| /*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) | |||
| @@ -307,10 +328,21 @@ namespace Tensorflow.Operations | |||
| protected override void _AddOpInternal(Operation op) | |||
| { | |||
| if (op.name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad") | |||
| { | |||
| } | |||
| Operation[] external_inputs = new Operation[0]; | |||
| Operation[] control_inputs = new Operation[0]; | |||
| if (op.inputs.Length == 0) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| // Remove any external control dependency on this op | |||
| (control_inputs, external_inputs) = _RemoveExternalControlEdges(op); | |||
| if (control_inputs.Length == 0) | |||
| op._add_control_input(GetControlPivot().op); | |||
| foreach (var x in op.outputs) | |||
| _values.Add(x.name); | |||
| } | |||
| else | |||
| { | |||
| @@ -378,6 +410,93 @@ namespace Tensorflow.Operations | |||
| _AddOpInternal(op); | |||
| } | |||
| /// <summary> | |||
| /// Adds a loop that counts the number of iterations. | |||
| /// </summary> | |||
| /// <param name="outer_grad_state">The outer grad state. None if not nested.</param> | |||
| /// <returns>The number of iterations taken by the forward loop and the loop index.</returns> | |||
| public (Tensor, Tensor) AddForwardLoopCounter(GradLoopState outer_grad_state) | |||
| { | |||
| var n = constant_op.constant(0, name: "f_count"); | |||
| if (outer_grad_state != null) | |||
| throw new NotImplementedException("AddForwardLoopCounter"); | |||
| Enter(); | |||
| AddName(n.name); | |||
| var enter_n = _Enter(n, | |||
| _name, | |||
| is_constant: false, | |||
| parallel_iterations: _parallel_iterations, | |||
| name: "f_count"); | |||
| _loop_enters.Add(enter_n); | |||
| var m1 = merge(new[] { enter_n, enter_n }); | |||
| var merge_n = m1[0]; | |||
| var switch_n = @switch(merge_n, _pivot); | |||
| var index = math_ops.add(switch_n[1], 1); | |||
| var next_n = _NextIteration(index); | |||
| merge_n.op._update_input(1, next_n); | |||
| var total_iterations = exit(switch_n[0], name: "f_count"); | |||
| loop_exits.append(total_iterations); | |||
| ExitResult(new[] { total_iterations }); | |||
| Exit(); | |||
| return (total_iterations, next_n); | |||
| } | |||
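`AddForwardLoopCounter` wires the classic dataflow loop: `Enter` brings the constant 0 into the frame, `Merge` joins it with the value fed back by `NextIteration`, `Switch` routes on the loop predicate (`_pivot`), the live branch is incremented and fed back, and the other branch `Exit`s with the total iteration count. Read purely as control flow, ignoring how the graph ops execute, it reduces to a plain counter; a conceptual sketch under that simplification:

```csharp
using System;

static class ForwardCounterSketch
{
    // Conceptual reading of AddForwardLoopCounter: while the forward loop's
    // predicate holds, the counter is incremented once per iteration; the
    // exit value is the total number of iterations the forward loop took.
    static int CountForwardIterations(Func<int, bool> loopPredicate)
    {
        int n = 0;                       // Enter(constant 0, "f_count")
        while (loopPredicate(n))         // Merge -> Switch on _pivot
            n = n + 1;                   // add 1 -> NextIteration feeds Merge
        return n;                        // Exit: total_iterations
    }

    static void Main()
    {
        // e.g. a forward while_loop that runs 5 times.
        Console.WriteLine(CountForwardIterations(i => i < 5)); // 5
    }
}
```

`AddBackpropLoopCounter` below is the mirror image: it enters with that total and counts down to zero, using the count as the backprop loop's termination predicate.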
| /// <summary> | |||
| /// Add the backprop loop that controls the iterations. | |||
| /// </summary> | |||
| /// <param name="count">The number of iterations for backprop.</param> | |||
| /// <param name="outer_grad_state">The outer grad state. None if not nested.</param> | |||
| /// <returns>The loop index.</returns> | |||
| public Tensor AddBackpropLoopCounter(Tensor count, GradLoopState outer_grad_state) | |||
| { | |||
| Tensor one = null; | |||
| var in_separate_functions = count.graph != ops.get_default_graph(); | |||
| if (in_separate_functions) | |||
| // Brings the count into this graph | |||
| count = array_ops.identity(count); | |||
| else | |||
| one = constant_op.constant(1, name: "b_count"); | |||
| Enter(); | |||
| AddName(count.name); | |||
| var enter_count = _Enter( | |||
| count, | |||
| _name, | |||
| is_constant: false, | |||
| parallel_iterations: _parallel_iterations, | |||
| name: "b_count"); | |||
| loop_enters.append(enter_count); | |||
| var merge_count = merge(new[] { enter_count, enter_count })[0]; | |||
| _pivot_for_pred = merge_count; | |||
| if (in_separate_functions) | |||
| one = constant_op.constant(1, name: "b_count"); | |||
| var pred = math_ops.greater_equal(merge_count, one); | |||
| _pivot = gen_control_flow_ops.loop_cond(pred, name: "b_count"); | |||
| var switch_count = @switch(merge_count, _pivot); | |||
| var index = math_ops.subtract(switch_count[1], one); | |||
| _pivot_for_body = index; | |||
| var next_count = _NextIteration(index); | |||
| merge_count.op._update_input(1, next_count); | |||
| var final_zero = exit(switch_count[0], name: "b_count"); | |||
| loop_exits.append(final_zero); | |||
| // Force the stack pops of i-th execution of an inner loop to be ordered | |||
| // before the pops of (i+1)-th execution of the same inner loop. | |||
| if (outer_grad_state != null) | |||
| throw new NotImplementedException("outer_grad_state"); | |||
| //outer_grad_state.grad_sync._add_control_input(final_zero.op); | |||
| ExitResult(new[] { final_zero }); | |||
| Exit(); | |||
| return next_count; | |||
| } | |||
| /// <summary> | |||
| /// Add `val` to the current context and its outer context recursively. | |||
| /// </summary> | |||
| @@ -401,17 +520,27 @@ namespace Tensorflow.Operations | |||
| grad_ctxt = grad_ctxt.GetWhileContext(); | |||
| if (grad_ctxt.grad_state != null) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| var forward_ctxt = val.op.GetWhileContext(); | |||
| if (control_flow_util.IsLoopExit(val.op)) | |||
| { | |||
| forward_ctxt = forward_ctxt.outer_context as WhileContext; | |||
| if (forward_ctxt != null) | |||
| forward_ctxt = forward_ctxt.GetWhileContext(); | |||
| throw new NotImplementedException("control_flow_util.IsLoopExit"); | |||
| } | |||
| if(forward_ctxt == grad_ctxt.grad_state.forward_context) | |||
| { | |||
| throw new NotImplementedException("forward_ctxt == grad_ctxt.grad_state.forward_context"); | |||
| /*real_val = grad_ctxt.grad_state.GetRealValue(val); | |||
| _external_values[val.name] = real_val; | |||
| return real_val;*/ | |||
| } | |||
| } | |||
| } | |||
| if (_outer_context != null) | |||
| result = _outer_context.AddValue(val); | |||
| // Create an Enter to make `result` known to this loop context. | |||
| Tensor enter = null; | |||
| tf_with(ops.control_dependencies(null), delegate | |||
| @@ -443,6 +572,9 @@ namespace Tensorflow.Operations | |||
| return result; | |||
| } | |||
| public override bool IsWhileContext() | |||
| => true; | |||
| public override WhileContext GetWhileContext() | |||
| { | |||
| return this; | |||