@@ -20,6 +20,7 @@ using System.Linq;
 using Tensorflow.Operations.ControlFlows;
 using static Tensorflow.ControlFlowContextDef;
 using static Tensorflow.Binding;
+using util = Tensorflow.control_flow_util;

 namespace Tensorflow.Operations
 {
@@ -146,6 +147,14 @@ namespace Tensorflow.Operations
             graph._set_control_flow_context(last_context);
         }
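+        /// <summary>
+        /// Make a list of tensors available in the outer context.
+        /// </summary>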
+        public void ExitResult(Tensor[] result)
+        {
+            if (_outer_context != null)
+            {
+                throw new NotImplementedException("ExitResult");
+            }
+        }
+
         /// <summary>
         /// Add `op` to the current context.
         /// </summary>
@@ -172,6 +181,11 @@ namespace Tensorflow.Operations
             return null;
         }

+        public void AddName(string name)
+        {
+            _values.Add(name);
+        }
+
         /// <summary>
         /// Notifies a scope about an operator added to an inner scope.
         /// </summary>
@@ -246,9 +260,11 @@ namespace Tensorflow.Operations
             }
             else
             {
-                foreach (Tensor x in op.control_inputs)
+                foreach (Operation x in op.control_inputs)
                 {
-                    throw new NotImplementedException("");
+                    var ctxt = util.GetOutputContext(x);
+                    if (ctxt != null && ctxt.GetWhileContext() == while_ctxt)
+                        internal_control_inputs.append(x);
                 }
             }
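For readers following the rewritten loop above: `util.GetOutputContext` mirrors the Python `control_flow_util` helper. A minimal sketch of its behavior, reconstructed from the Python utility (the exact TF.NET signature may differ):

    // Sketch only: an op's outputs live in the op's own context, except that
    // an Exit op's outputs belong to the enclosing (outer) context.
    public static ControlFlowContext GetOutputContext(Operation op)
    {
        var ctxt = op._get_control_flow_context();
        if (IsLoopExit(op) && ctxt != null)
            ctxt = ctxt.outer_context;
        return ctxt;
    }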
@@ -288,6 +304,14 @@ namespace Tensorflow.Operations
             throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}");
         }

+        public virtual bool IsWhileContext()
+        {
+            throw new NotImplementedException("IsWhileContext");
+        }
+
+        public virtual bool IsCondContext()
+            => false;
+
         public object to_proto()
         {
             throw new NotImplementedException();
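`IsWhileContext`/`IsCondContext` are virtual so each concrete context can report its own kind; the `WhileContext` override (`=> true`) is added at the end of this diff. A `CondContext` counterpart would presumably be the mirror image (sketch only, not part of this change):

    public class CondContext : ControlFlowContext
    {
        public override bool IsCondContext() => true;
    }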
@@ -14,6 +14,12 @@
    limitations under the License.
 ******************************************************************************/

+using System;
+using System.Linq;
+using System.Collections.Generic;
+using util = Tensorflow.control_flow_util;
+using static Tensorflow.Binding;
+
 namespace Tensorflow.Operations.ControlFlows
 {
     /// <summary>
@@ -21,6 +27,7 @@ namespace Tensorflow.Operations.ControlFlows
     /// </summary>
     public class ControlFlowState
     {
+        Dictionary<ControlFlowContext, GradLoopState> _map;
+
         //class ControlFlowState(object):
         //  """Maintain the mapping from the loops to their grad states."""

@@ -40,51 +47,67 @@ namespace Tensorflow.Operations.ControlFlows
         //      return self._map.get(forward_ctxt)
         //    return None
-        //  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
-        //    """Process all the "unused" loop exits.
-        //    The "unused" exits of the loops are added to `unused_exits`. An exit is
-        //    unused if its pending_count is 0. If there is an exit with real gradient,
-        //    all these deferred exits will enter the backprop loop with zero gradient.
-        //    Otherwise, they will enter the backprop loop with None. As an example,
-        //    people often write:
-        //    ```python
-        //    v1, _ = tf.while_loop(p, b, [x1, x2])
-        //    result = gradients(v1, x1)
-        //    ```
-        //    The exit node for x2 is not included by the betweenness analysis. But we
-        //    need to backprop x2 if x2 is involved in computing v1.
-        //    Args:
-        //      pending_count: The number of backprop inputs for every op.
-        //      to_ops_set: The set of ops for ys in gradients(ys, xs)
-        //    Returns:
-        //      The set of unused loop exits that we know at this point we need
-        //      to backprop.
-        //    """
-        //    loop_exits = []
-        //    for grad_state in self._map.values():
-        //      for y in grad_state.forward_loop_exits:
-        //        if pending_count[y.op] == 0:
-        //          grad_state.pending_exits_count -= 1
-        //          if y.op not in to_ops_set:
-        //            grad_state.unused_exits.append(y)
-        //          if grad_state.pending_exits_count == 0:
-        //            loop_exits.extend(grad_state.unused_exits)
-        //      # Need to include Enters in backprop for higher-order gradients.
-        //      for y in grad_state.forward_context.loop_enters:
-        //        if pending_count[y.op] == 0:
-        //          pending_count[y.op] = 1
-        //    return loop_exits
-
-        //  def EnterGradWhileContext(self, op, before):
-        //    """Enter the WhileContext for gradient computation."""
-        //    grad_state = self.GetGradState(op, before)
-        //    if grad_state:
-        //      grad_state.grad_context.Enter()
+        public ControlFlowState()
+        {
+            _map = new Dictionary<ControlFlowContext, GradLoopState>();
+        }
+
+        /// <summary>
+        /// Return the grad state for this op if it's in a forward loop context.
+        /// </summary>
+        /// <param name="op"></param>
+        /// <param name="before"></param>
+        /// <returns></returns>
+        public GradLoopState GetGradState(Operation op, bool before)
+        {
+            ControlFlowContext forward_ctxt = null;
+            if (before && util.IsLoopExit(op))
+            {
+                forward_ctxt = op._get_control_flow_context();
+                forward_ctxt = forward_ctxt.outer_context;
+                if (forward_ctxt != null)
+                    forward_ctxt = forward_ctxt.GetWhileContext();
+            }
+            else
+                forward_ctxt = util.GetWhileContext(op);
+
+            if (forward_ctxt != null)
+                return _map.get(forward_ctxt);
+            return null;
+        }
+
+        public Tensor[] ProcessUnusedLoopExits(Dictionary<string, int> pending_count, List<Operation> to_ops_set)
+        {
+            var loop_exits = new List<Tensor>();
+            foreach (var grad_state in _map.Values)
+            {
+                foreach (var y in grad_state.forward_loop_exits)
+                {
+                    // Match the Python semantics: a missing entry counts as zero.
+                    if (!pending_count.ContainsKey(y.op.name) || pending_count[y.op.name] == 0)
+                    {
+                        grad_state.pending_exits_count -= 1;
+                        if (!to_ops_set.Contains(y.op))
+                            grad_state.unused_exits.append(y);
+                        if (grad_state.pending_exits_count == 0)
+                            loop_exits.extend(grad_state.unused_exits);
+                    }
+                }
+
+                // Need to include Enters in backprop for higher-order gradients.
+                foreach (var y in grad_state.forward_context.loop_enters)
+                {
+                    if (!pending_count.ContainsKey(y.op.name) || pending_count[y.op.name] == 0)
+                        pending_count[y.op.name] = 1;
+                }
+            }
+            return loop_exits.ToArray();
+        }
+
+        public void EnterGradWhileContext(Operation op, bool before)
+        {
+            var grad_state = GetGradState(op, before);
+            if (grad_state != null)
+                grad_state.grad_context.Enter();
+        }
         //  def ExitGradWhileContext(self, op, before):
         //    """Exit the WhileContext for gradient computation."""
@@ -118,6 +141,32 @@ namespace Tensorflow.Operations.ControlFlows
         //      if loop_exit.op not in between_ops:
         //        between_ops.add(loop_exit.op)
         //        between_op_list.append(loop_exit.op)

+        public void AddWhileContext(Operation op, List<Operation> between_op_list, List<Operation> between_ops)
+        {
+            var forward_ctxt = op.GetWhileContext();
+            var grad_state = _map.ContainsKey(forward_ctxt) ? _map[forward_ctxt] : null;
+            if (grad_state == null)
+            {
+                GradLoopState outer_grad_state = null;
+                var outer_forward_ctxt = forward_ctxt.outer_context;
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();
+                if (outer_forward_ctxt != null)
+                    outer_grad_state = _map.get(outer_forward_ctxt);
+                grad_state = new GradLoopState(forward_ctxt, outer_grad_state);
+                _map[forward_ctxt] = grad_state;
+
+                // We need to include all exits of a loop for backprop.
+                foreach (var loop_exit in grad_state.forward_loop_exits)
+                {
+                    if (!between_ops.Contains(loop_exit.op))
+                    {
+                        between_ops.add(loop_exit.op);
+                        between_op_list.append(loop_exit.op);
+                    }
+                }
+            }
+        }
         //  def ZerosLikeForExit(self, val):
         //    """Create zeros_like gradient for a loop exit.

@@ -174,70 +223,69 @@ namespace Tensorflow.Operations.ControlFlows
         //      result = array_ops.zeros_like(val, optimize=False)
         //    return result
-        //  def ZerosLike(self, op, index):
-        //    """Create zeros_like for the specified output of an op.
-        //    If op is in a while loop that is part of gradients(), this method
-        //    must be called in its grad loop context.
-        //    Args:
-        //      op: A tensorflow operation.
-        //      index: the index for a specific output of the op.
-        //    Returns:
-        //      A zero tensor of the same shape of op.outputs[index].
-        //    """
-        //    if util.IsLoopSwitch(op):
-        //      return None
-        //    if op.graph._building_function:  # pylint: disable=protected-access
-        //      # The optimization here is tricky to apply to functions
-        //      return array_ops.zeros_like(op.outputs[index])
-        //    dead_branch = util.IsSwitch(op)
-        //    forward_ctxt = _GetWhileContext(op)
-        //    grad_state = self._map.get(forward_ctxt)
-        //    if grad_state is None:
-        //      # op is not in a while loop that is part of gradients().
-        //      return ZerosLikeOutsideLoop(op, index)
-        //    op_ctxt = op._get_control_flow_context()
-        //    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
-        //    shape = val.get_shape()
-        //    if shape.is_fully_defined():
-        //      # If the shape is known statically, just create a zero tensor with
-        //      # the right shape in the grad loop context.
-        //      result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
-        //      if dead_branch:
-        //        # op is a cond switch. Guard the zero tensor with a switch.
-        //        pred = grad_state.history_map.get(op_ctxt.pred.name)
-        //        branch = op_ctxt.branch
-        //        result = _SwitchRefOrTensor(result, pred)[1 - branch]
-        //    else:
-        //      # Unknown shape so keep a history of the shape at runtime.
-        //      if dead_branch:
-        //        # Need to add a special switch to guard the value.
-        //        pred = op_ctxt.pred
-        //        branch = op_ctxt.branch
-        //        op_ctxt.outer_context.Enter()
-        //        val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
-        //        zeros_shape = array_ops.shape_internal(val, optimize=False)
-        //        op_ctxt.outer_context.Exit()
-        //        val.op._set_control_flow_context(op_ctxt)
-        //        zeros_shape.op._set_control_flow_context(op_ctxt)
-        //      else:
-        //        op_ctxt.Enter()
-        //        zeros_shape = array_ops.shape_internal(val, optimize=False)
-        //        op_ctxt.Exit()
-        //      # Add forward accumulator for shape.
-        //      grad_state.grad_context.Exit()
-        //      history_zeros_shape = grad_state.AddForwardAccumulator(
-        //          zeros_shape, dead_branch=dead_branch)
-        //      grad_state.grad_context.Enter()
-        //      # Create a zero tensor with the right shape.
-        //      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
-        //                                                     zeros_shape, dead_branch)
-        //      result = array_ops.zeros(shape, val.dtype)
-        //    return result
+        public Tensor ZerosLike(Operation op, int index)
+        {
+            if (util.IsLoopSwitch(op))
+                return null;
+            if (op.graph.building_function)
+                return array_ops.zeros_like(op.outputs[index]);
+            var dead_branch = util.IsSwitch(op);
+            var forward_ctxt = util.GetWhileContext(op);
+            var grad_state = _map.get(forward_ctxt);
+            // op is not in a while loop that is part of gradients().
+            if (grad_state == null)
+                return ZerosLikeOutsideLoop(op, index);
+            throw new NotImplementedException("ZerosLike");
+        }
+
+        public Tensor ZerosLikeOutsideLoop(Operation op, int index)
+        {
+            var val = op.outputs[index];
+            if (!util.IsSwitch(op))
+            {
+                if (val.dtype == dtypes.resource)
+                    throw new NotImplementedException("ZerosLikeOutsideLoop");
+                    /*return array_ops.zeros(
+                        gen_resource_variable_ops.variable_shape(val),
+                        dtype: default_gradient.get_zeros_dtype(val));*/
+                return array_ops.zeros_like(val, optimize: false);
+            }
+            else
+                throw new NotImplementedException("ZerosLikeOutsideLoop");
+        }
+
+        /// <summary>
+        /// Create zeros_like gradient for a loop exit.
+        /// </summary>
+        /// <param name="val"></param>
+        /// <returns></returns>
+        public Tensor ZerosLikeForExit(Tensor val)
+        {
+            Tensor result = null;
+            var val_shape = val.TensorShape;
+            var forward_ctxt = val.op._get_control_flow_context();
+            var outer_forward_ctxt = forward_ctxt.outer_context;
+            if (outer_forward_ctxt != null)
+                outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();
+            GradLoopState outer_grad_state = null;
+            if (outer_forward_ctxt != null)
+                outer_grad_state = _map.get(outer_forward_ctxt);
+            // This is a nested loop.
+            if (outer_grad_state != null)
+            {
+                throw new NotImplementedException("ZerosLikeForExit");
+            }
+            else
+            {
+                // If the shape is known statically, just create a zero tensor
+                // with the right shape.
+                if (val_shape.is_fully_defined())
+                    result = array_ops.zeros(val_shape.dims, val.dtype);
+                else
+                    result = array_ops.zeros_like(val, optimize: false);
+            }
+            return result;
+        }

         //  def PostProcessing(self):
         //    """Perform postprocessing at the end of gradients().
@@ -16,41 +16,16 @@
 using System;
 using System.Collections;
+using System.Collections.Generic;
+using static Tensorflow.Binding;

 namespace Tensorflow.Operations.ControlFlows
 {
+    /// <summary>
+    /// The state used for constructing the gradient graph for a while loop.
+    /// </summary>
     public class GradLoopState
     {
-        //class GradLoopState(object):
-        //  """The state used for constructing the gradient graph for a while loop.
-        //  We create a GradLoopState for each while loop in forward and its
-        //  corresponding while loop in backprop. This gives us access to both
-        //  the forward and the backprop WhileContexts.
-        //  During the construction of gradient graph, any time when we detect
-        //  a forward value that is needed for backprop, we create a history
-        //  accumulator and add it to `history_map`. Any time when we backprop
-        //  a loop switch op (in _SwitchGrad), we add the grad merge op in
-        //  `switch_map`.
-        //  """
-        //  def __init__(self, forward_ctxt, outer_grad_state):
-        //    # The grad loop state for the outer while loop.
-        //    self._outer_grad_state = None
-        //    # The while loop context for forward.
-        //    self._forward_context = None
-        //    # The loop counter added by AddForwardLoopCounter. It is the value
-        //    # of the loop counter for the next iteration.
-        //    self._forward_index = None
-        //    # A sync op for forward.
-        //    self._forward_sync = None
-        //    # The while loop context for backprop.
         private WhileContext _grad_context = null;

         public WhileContext grad_context => _grad_context;
@@ -65,156 +40,91 @@ namespace Tensorflow.Operations.ControlFlows
         //    # Information needed by backprop.
         private Hashtable _history_map = new Hashtable();
         public Hashtable history_map => _history_map;

-        private Hashtable _switch_map = new Hashtable();
-        public Hashtable switch_map => _switch_map;
-        //    self._unused_exits = []
-        //    self._deferred_exits = []
-        //    self._forward_loop_exits = list(forward_ctxt.loop_exits)
-        //    self._pending_exits_count = len(forward_ctxt.loop_exits)
-        //    self._outer_grad_state = outer_grad_state
-        //    if outer_grad_state:
-        //      outer_forward_ctxt = outer_grad_state.forward_context
-        //    else:
-        //      if not hasattr(forward_ctxt, "outer_context"):
-        //        raise ValueError("Failed to call gradients on a while loop without"
-        //                         "properly serializing graph via MetaGraphDef")
-        //      outer_forward_ctxt = forward_ctxt.outer_context
-        //    # Add the forward loop counter.
-        //    with forward_ctxt._graph.as_default():  # pylint: disable=protected-access
-        //      if outer_forward_ctxt:
-        //        outer_forward_ctxt.Enter()
-        //      cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
-        //      if outer_forward_ctxt:
-        //        outer_forward_ctxt.Exit()
-        //    self._forward_context = forward_ctxt
-        //    self._forward_index = forward_index
-        //    # Add the backprop WhileContext, and the backprop loop counter.
-        //    if outer_grad_state:
-        //      # This is a nested loop. Remember the iteration counts for each
-        //      # execution of this inner loop.
-        //      outer_forward_ctxt.AddName(cnt.name)
-        //      history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
-        //      outer_grad_ctxt = outer_grad_state.grad_context
-        //      outer_grad_ctxt.Enter()
-        //      self._grad_context = WhileContext(
-        //          maximum_iterations=forward_ctxt.maximum_iterations,
-        //          parallel_iterations=forward_ctxt.parallel_iterations,
-        //          back_prop=forward_ctxt.back_prop,
-        //          swap_memory=forward_ctxt.swap_memory,
-        //          name=forward_ctxt.name,
-        //          grad_state=self)
-        //      real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
-        //      self._grad_index = self._grad_context.AddBackpropLoopCounter(
-        //          real_cnt, outer_grad_state)
-        //      outer_grad_ctxt.Exit()
-        //    else:
-        //      if outer_forward_ctxt:
-        //        outer_forward_ctxt.Enter()
-        //      self._grad_context = WhileContext(
-        //          maximum_iterations=forward_ctxt.maximum_iterations,
-        //          parallel_iterations=forward_ctxt.parallel_iterations,
-        //          back_prop=forward_ctxt.back_prop,
-        //          swap_memory=forward_ctxt.swap_memory,
-        //          name=forward_ctxt.name,
-        //          grad_state=self)
-        //      self._grad_index = self._grad_context.AddBackpropLoopCounter(
-        //          cnt, outer_grad_state)
-        //      if outer_forward_ctxt:
-        //        outer_forward_ctxt.Exit()
-        //  @property
-        //  def outer_grad_state(self):
-        //    """The grad loop state for outer loop."""
-        //    return self._outer_grad_state
-        //  @property
-        //  def forward_context(self):
-        //    """The while loop context for forward."""
-        //    return self._forward_context
-        //  @property
-        //  def forward_index(self):
-        //    """The loop index of forward loop."""
-        //    return self._forward_index
-        //  @property
-        //  def forward_sync(self):
-        //    """A control trigger node for synchronization in the forward loop.
-        //    One main use is to keep the push ops of a stack executed in the
-        //    iteration order.
-        //    """
-        //    if self._forward_sync is None:
-        //      with ops.control_dependencies(None):
-        //        self._forward_sync = control_trigger(name="f_sync")
-        //      self._forward_sync._set_control_flow_context(self._forward_context)
-        //      self._forward_index.op._add_control_input(self._forward_sync)
-        //    return self._forward_sync
-        //  @property
-        //  def grad_context(self):
-        //    """The corresponding WhileContext for gradient."""
-        //    return self._grad_context
-        //  @property
-        //  def grad_index(self):
-        //    """The loop index of backprop loop."""
-        //    return self._grad_index
-        //  @property
-        //  def grad_sync(self):
-        //    """A control trigger node for synchronization in the grad loop.
-        //    One main use is to keep the pop ops of a stack executed in the
-        //    iteration order.
-        //    """
-        //    if self._grad_sync is None:
-        //      with ops.control_dependencies(None):
-        //        self._grad_sync = control_trigger(name="b_sync")
-        //      self._grad_sync._set_control_flow_context(self._grad_context)
-        //      self._grad_index.op._add_control_input(self._grad_sync)
-        //      if self._grad_context.outer_context:
-        //        self._grad_context.outer_context.AddInnerOp(self._grad_sync)
-        //    return self._grad_sync
-        //  @property
-        //  def history_map(self):
-        //    """The map that records all the tensors needed for backprop."""
-        //    return self._history_map
-        //  @property
-        //  def switch_map(self):
-        //    """The map that records all the Switch ops for the while loop."""
-        //    return self._switch_map
-        //  @property
-        //  def unused_exits(self):
-        //    """The list of "unused" exits."""
-        //    return self._unused_exits
-        //  @property
-        //  def deferred_exits(self):
-        //    """The list of "deferred" exits."""
-        //    return self._deferred_exits
-        //  @property
-        //  def forward_loop_exits(self):
-        //    """The list of exits of the forward loop."""
-        //    return self._forward_loop_exits
-        //  @property
-        //  def pending_exits_count(self):
-        //    """The number of exits we expect to see but haven't."""
-        //    return self._pending_exits_count
-        //  @pending_exits_count.setter
-        //  def pending_exits_count(self, cnt):
-        //    """Set the pending count to cnt."""
-        //    self._pending_exits_count = cnt
+        Dictionary<Operation, Tensor> _switch_map = new Dictionary<Operation, Tensor>();
+        public Dictionary<Operation, Tensor> switch_map => _switch_map;
+
+        /// <summary>
+        /// The while loop context for forward.
+        /// </summary>
+        WhileContext _forward_context;
+        public WhileContext forward_context => _forward_context;
+
+        /// <summary>
+        /// The grad loop state for the outer while loop.
+        /// </summary>
+        GradLoopState _outer_grad_state;
+        public GradLoopState outer_grad_state => _outer_grad_state;
+
+        Tensor _forward_index;
+        Tensor _grad_index;
+
+        Tensor[] _forward_loop_exits;
+        /// <summary>
+        /// The list of exits of the forward loop.
+        /// </summary>
+        public Tensor[] forward_loop_exits => _forward_loop_exits;
+
+        List<Tensor> _deferred_exits;
+        public List<Tensor> deferred_exits => _deferred_exits;
+
+        List<Tensor> _unused_exits;
+        public List<Tensor> unused_exits => _unused_exits;
+
+        /// <summary>
+        /// The number of exits we expect to see but haven't.
+        /// </summary>
+        public int pending_exits_count { get; set; }
+
+        public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
+        {
+            // Information needed by backprop.
+            _unused_exits = new List<Tensor>();
+            _deferred_exits = new List<Tensor>();
+            _forward_loop_exits = list(forward_ctxt.loop_exits);
+            pending_exits_count = len(forward_ctxt.loop_exits);
+            _outer_grad_state = outer_grad_state_;
+            ControlFlowContext outer_forward_ctxt = null;
+            if (outer_grad_state_ != null)
+                outer_forward_ctxt = outer_grad_state_.forward_context;
+
+            // Add the forward loop counter.
+            // with forward_ctxt._graph.as_default():
+            Tensor cnt, forward_index;
+            {
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt.Enter();
+                (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state);
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt.Exit();
+            }
+            _forward_context = forward_ctxt;
+            _forward_index = forward_index;
+
+            // Add the backprop WhileContext, and the backprop loop counter.
+            if (outer_grad_state != null)
+            {
+                // This is a nested loop. Remember the iteration counts for each
+                // execution of this inner loop.
+                throw new NotImplementedException("GradLoopState");
+            }
+            else
+            {
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt.Enter();
+                _grad_context = new WhileContext(
+                    maximum_iterations: forward_ctxt.maximum_iterations,
+                    parallel_iterations: forward_ctxt.parallel_iterations,
+                    back_prop: forward_ctxt.back_prop,
+                    swap_memory: forward_ctxt.swap_memory,
+                    name: forward_ctxt.name,
+                    grad_state: this);
+                _grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state);
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt.Exit();
+            }
+        }
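+
+        // The nested-loop branch above corresponds to the removed Python:
+        // it would record the forward iteration count via
+        // outer_grad_state.AddForwardAccumulator(cnt) and build the grad
+        // WhileContext inside the outer grad context.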
         /// <summary>
         /// Add an accumulator for each forward tensor that is needed in backprop.
@@ -0,0 +1,36 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+    public class MergeOutput
+    {
+        Tensor output;
+        Tensor value_index;
+
+        public MergeOutput(Tensor[] values)
+        {
+            output = values[0];
+            value_index = values[1];
+        }
+
+        public Tensor this[int idx]
+        {
+            get
+            {
+                switch (idx)
+                {
+                    case 0:
+                        return output;
+                    case 1:
+                        return value_index;
+                    default:
+                        return null;
+                }
+            }
+        }
+
+        public static implicit operator Tensor(MergeOutput merge)
+            => merge.output;
+    }
+}
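`MergeOutput` lets a call site treat a merge result either as the merged tensor or as the (output, value_index) pair; both uses appear in the `WhileContext` changes below. A small usage sketch:

    var m = merge(new[] { x, x });
    Tensor merged = m;     // implicit conversion, same as m[0]
    Tensor branch = m[1];  // value_index: which input produced the output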
@@ -32,12 +32,17 @@ namespace Tensorflow.Operations
         bool _back_prop=true;
         GradLoopState _grad_state =null;
         Tensor _maximum_iterations;
+        public Tensor maximum_iterations => _maximum_iterations;
         int _parallel_iterations;
+        public int parallel_iterations => _parallel_iterations;
         bool _swap_memory;
+        public bool swap_memory => _swap_memory;
         Tensor _pivot_for_pred;
         Tensor _pivot_for_body;
         List<Tensor> _loop_exits;
+        public List<Tensor> loop_exits => _loop_exits;
         List<Tensor> _loop_enters;
+        public List<Tensor> loop_enters => _loop_enters;
         Graph _graph;
         public override GradLoopState grad_state => _grad_state;
         public override bool back_prop => _back_prop;
@@ -109,7 +114,7 @@ namespace Tensorflow.Operations
         /// <summary>
         /// Add the loop termination condition and body to the graph.
         /// </summary>
-        internal Tensor[] BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
+        internal LoopVar<TItem> BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
             Func<LoopVar<TItem>, LoopVar<TItem>> body,
             LoopVar<TItem> loop_vars,
             TensorShape[] shape_invariants,
@@ -132,14 +137,16 @@ namespace Tensorflow.Operations
                 pred, body, original_loop_vars, loop_vars_tensors, shape_invariants);
             Exit();

-            var flat_result = original_body_result;
+            var flat_result = nest.flatten2(original_body_result)
+                .Select(x => x as ITensorOrTensorArray)
+                .ToArray();

             var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars);
-            var packed_exit_vars = nest.pack_sequence_as(
+            var packed_exit_vars = nest.pack_sequence_as2(
                 structure: original_body_result,
                 flat_sequence: exit_vars_with_tensor_arrays);

-            return packed_exit_vars as Tensor[];
+            return packed_exit_vars;
         }
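+
+        // BuildLoop now returns the packed LoopVar<TItem>, so callers get back
+        // the nested structure they passed in as loop_vars, with TensorArray
+        // flow tensors converted back into TensorArrays.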
         private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array)
@@ -167,7 +174,7 @@ namespace Tensorflow.Operations
         /// <param name="loop_vars"></param>
         /// <param name="shape_invariants"></param>
         /// <returns></returns>
-        private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
+        private (LoopVar<TItem>, Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
             Func<LoopVar<TItem>, LoopVar<TItem>> body,
             LoopVar<TItem> original_loop_vars,
             Tensor[] loop_vars,
@@ -221,6 +228,7 @@ namespace Tensorflow.Operations
             var merge_vars = enter_vars
                 .Select(x => merge(new[] { x, x }))
+                .Select(m => (Tensor)m)
                 .ToArray();

             _pivot_for_pred = merge_vars[0];
@@ -250,13 +258,15 @@ namespace Tensorflow.Operations
             var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

             // Store body_result to keep track of TensorArrays returned by body
-            var original_body_result = new[] { body_result };
+            var original_body_result = body_result;
             // Convert TensorArrays returned by body into their flow variables
-            var result = new[] { body_result };
+            var result = nest.flatten2(body_result)
+                .Select(x => _convert_tensorarray_to_flow(x))
+                .ToArray();
+            // result = ops.convert_n_to_tensor_or_composite(result);

             var next_vars = new List<Tensor>();
-            //foreach (var (m, v) in zip(merge_vars, result))
-            //    next_vars.Add(_AddNextAndBackEdge(m, v));
+            foreach (var (m, v) in zip(merge_vars, result))
+                next_vars.Add(_AddNextAndBackEdge(m, v));

             // Add the exit ops.
             var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
@@ -264,7 +274,7 @@ namespace Tensorflow.Operations

             // Exit the loop.
             // ExitResult(exit_vars);
-            return (null, exit_vars.ToArray());
+            return (original_body_result, exit_vars.ToArray());
         }

         private void _FixControlInputsAndContext(Tensor[] enters)
@@ -282,7 +292,18 @@ namespace Tensorflow.Operations
                 var keep_as_control_input = true;
                 var op_ctxt = control_flow_util.GetOutputContext(op);
                 var outer_ctxt = outer_context;
-                throw new NotImplementedException("");
+                var outer_while_context = outer_ctxt == null ? null : outer_ctxt.GetWhileContext();
+                // Keep op as a control input only if its output context is found
+                // on the walk from outer_context upward, before falling off the
+                // graph root or reaching the enclosing WhileContext.
+                while (outer_ctxt != op_ctxt)
+                {
+                    if (outer_ctxt == null || outer_ctxt == outer_while_context)
+                    {
+                        keep_as_control_input = false;
+                        break;
+                    }
+                    outer_ctxt = outer_ctxt.outer_context;
+                }
+                if (keep_as_control_input)
+                    outer_control_inputs.append(op);
             }

             // op for op in control_inputs if self._IsInOuterContext(op)
             /*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op))
@@ -307,10 +328,21 @@ namespace Tensorflow.Operations
         protected override void _AddOpInternal(Operation op)
         {
| if (op.name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad") | |||||
| { | |||||
| } | |||||
             Operation[] external_inputs = new Operation[0];
+            Operation[] control_inputs = new Operation[0];
             if (op.inputs.Length == 0)
             {
-                throw new NotImplementedException("");
+                // Remove any external control dependency on this op.
+                (control_inputs, external_inputs) = _RemoveExternalControlEdges(op);
+                if (control_inputs.Length == 0)
+                    op._add_control_input(GetControlPivot().op);
+                foreach (var x in op.outputs)
+                    _values.Add(x.name);
             }
             else
             {
@@ -378,6 +410,93 @@ namespace Tensorflow.Operations
             _AddOpInternal(op);
         }

+        /// <summary>
+        /// Adds a loop that counts the number of iterations.
+        /// </summary>
+        /// <param name="outer_grad_state">The outer grad state. None if not nested.</param>
+        /// <returns>The number of iterations taken by the forward loop and the loop index.</returns>
+        public (Tensor, Tensor) AddForwardLoopCounter(GradLoopState outer_grad_state)
+        {
+            var n = constant_op.constant(0, name: "f_count");
+            if (outer_grad_state != null)
+                throw new NotImplementedException("AddForwardLoopCounter");
+
+            Enter();
+            AddName(n.name);
+            var enter_n = _Enter(n,
+                _name,
+                is_constant: false,
+                parallel_iterations: _parallel_iterations,
+                name: "f_count");
+            _loop_enters.Add(enter_n);
+
+            var merge_n = merge(new[] { enter_n, enter_n })[0];
+            var switch_n = @switch(merge_n, _pivot);
+
+            var index = math_ops.add(switch_n[1], 1);
+            var next_n = _NextIteration(index);
+            merge_n.op._update_input(1, next_n);
+
+            var total_iterations = exit(switch_n[0], name: "f_count");
+            loop_exits.append(total_iterations);
+            ExitResult(new[] { total_iterations });
+            Exit();
+            return (total_iterations, next_n);
+        }
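+
+        // The counter graph assembled above:
+        //   Enter(0) -> Merge -> Switch(pivot) --true--> Add(1) -> NextIteration
+        //                 ^___________________________________________|
+        //   Switch --false--> Exit -> total_iterations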
+        /// <summary>
+        /// Add the backprop loop that controls the iterations.
+        /// </summary>
+        /// <param name="count">The number of iterations for backprop.</param>
+        /// <param name="outer_grad_state">The outer grad state. None if not nested.</param>
+        /// <returns>The loop index.</returns>
+        public Tensor AddBackpropLoopCounter(Tensor count, GradLoopState outer_grad_state)
+        {
+            Tensor one = null;
+            var in_separate_functions = count.graph != ops.get_default_graph();
+            if (in_separate_functions)
+                // Brings the count into this graph.
+                count = array_ops.identity(count);
+            else
+                one = constant_op.constant(1, name: "b_count");
+
+            Enter();
+            AddName(count.name);
+            var enter_count = _Enter(
+                count,
+                _name,
+                is_constant: false,
+                parallel_iterations: _parallel_iterations,
+                name: "b_count");
+            loop_enters.append(enter_count);
+
+            var merge_count = merge(new[] { enter_count, enter_count })[0];
+            _pivot_for_pred = merge_count;
+
+            if (in_separate_functions)
+                one = constant_op.constant(1, name: "b_count");
+            var pred = math_ops.greater_equal(merge_count, one);
+            _pivot = gen_control_flow_ops.loop_cond(pred, name: "b_count");
+            var switch_count = @switch(merge_count, _pivot);
+
+            var index = math_ops.subtract(switch_count[1], one);
+            _pivot_for_body = index;
+            var next_count = _NextIteration(index);
+            merge_count.op._update_input(1, next_count);
+
+            var final_zero = exit(switch_count[0], name: "b_count");
+            loop_exits.append(final_zero);
+
+            // Force the stack pops of i-th execution of an inner loop to be ordered
+            // before the pops of (i+1)-th execution of the same inner loop.
+            if (outer_grad_state != null)
+                throw new NotImplementedException("outer_grad_state");
+                //outer_grad_state.grad_sync._add_control_input(final_zero.op);
+
+            ExitResult(new[] { final_zero });
+            Exit();
+            return next_count;
+        }
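+
+        // Unlike the forward counter, this loop counts down: it runs while
+        // merge_count >= 1 and feeds index = merge_count - 1 back through
+        // NextIteration, so the body executes exactly `count` times.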
         /// <summary>
         /// Add `val` to the current context and its outer context recursively.
         /// </summary>
@@ -401,17 +520,27 @@ namespace Tensorflow.Operations
                 grad_ctxt = grad_ctxt.GetWhileContext();
                 if (grad_ctxt.grad_state != null)
                 {
-                    throw new NotImplementedException("");
+                    var forward_ctxt = val.op.GetWhileContext();
+                    if (control_flow_util.IsLoopExit(val.op))
+                    {
+                        forward_ctxt = forward_ctxt.outer_context as WhileContext;
+                        if (forward_ctxt != null)
+                            forward_ctxt = forward_ctxt.GetWhileContext();
+                        throw new NotImplementedException("control_flow_util.IsLoopExit");
+                    }
+                    if (forward_ctxt == grad_ctxt.grad_state.forward_context)
+                    {
+                        throw new NotImplementedException("forward_ctxt == grad_ctxt.grad_state.forward_context");
+                        /*real_val = grad_ctxt.grad_state.GetRealValue(val);
+                        _external_values[val.name] = real_val;
+                        return real_val;*/
+                    }
                 }
             }

             if (_outer_context != null)
                 result = _outer_context.AddValue(val);
             // Create an Enter to make `result` known to this loop context.
             Tensor enter = null;
             tf_with(ops.control_dependencies(null), delegate
@@ -443,6 +572,9 @@ namespace Tensorflow.Operations
             return result;
         }

+        public override bool IsWhileContext()
+            => true;
+
         public override WhileContext GetWhileContext()
         {
             return this;