gradients/flow control: added much missing structure
| @@ -1,12 +1,79 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow.Gradients | |||
| { | |||
| /// <summary> | |||
| /// Gradients for operators defined in control_flow_ops.py.cs | |||
| /// </summary> | |||
| public class control_flow_grad | |||
| { | |||
| /// <summary> | |||
| /// Gradients for a Switch op are calculated using a Merge op. | |||
| /// | |||
| /// If the switch is a loop switch, it will be visited twice. We create | |||
| /// the merge on the first visit, and update the other input of the merge | |||
| /// on the second visit. A next_iteration is also added on the second visit. | |||
| /// </summary> | |||
| /// <returns>Gradients with respect to the Switch op's inputs.</returns> | |||
| public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads) | |||
| { | |||
| throw new NotImplementedException("_SwitchGrad"); | |||
| //graph = ops.get_default_graph() | |||
| //# pylint: disable=protected-access | |||
| //op_ctxt = op._get_control_flow_context() | |||
| //grad_ctxt = graph._get_control_flow_context() | |||
| //# pylint: enable=protected-access | |||
| //if isinstance(op_ctxt, WhileContext): | |||
| // merge_grad = grad_ctxt.grad_state.switch_map.get(op) | |||
| // if merge_grad is not None: | |||
| // # This is the second time this Switch is visited. It comes from | |||
| // # the non-exit branch of the Switch, so update the second input | |||
| // # to the Merge. | |||
| // # TODO(yuanbyu): Perform shape inference with this new input. | |||
| // if grad[1] is not None: | |||
| // # pylint: disable=protected-access | |||
| // control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1], | |||
| // enforce_shape_invariant=False) | |||
| // # pylint: enable=protected-access | |||
| // return None, None | |||
| // elif grad[0] is not None: | |||
| // # This is the first time this Switch is visited. It comes from | |||
| // # the Exit branch, which is grad[0]. grad[1] is empty at this point. | |||
| // # Use grad[0] for both inputs to merge for now, but update the second | |||
| // # input of merge when we see this Switch the second time. | |||
| // merge_grad = merge([grad[0], grad[0]], name="b_switch")[0] | |||
| // grad_ctxt.grad_state.switch_map[op] = merge_grad | |||
| // return merge_grad, None | |||
| // else: | |||
| // # This is the first time this Switch is visited. It comes from the | |||
| // # Identity branch. Such a Switch has `None` gradient for the Exit branch, | |||
| // # meaning the output is not differentiable. | |||
| // return None, None | |||
| //elif isinstance(op_ctxt, CondContext): | |||
| // zero_grad = grad[1 - op_ctxt.branch] | |||
| // # At this point, we have created zero_grad guarded by the right switch. | |||
| // # Unfortunately, we may still get None here for not trainable data types. | |||
| // if zero_grad is None: | |||
| // # For resource variables we get None always on the other branch, so bypass | |||
| // # this. | |||
| // if op.inputs[0].dtype == dtypes.resource: | |||
| // return merge( | |||
| // [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None | |||
| // return None, None | |||
| // return merge(grad, name="cond_grad")[0], None | |||
| //else: | |||
| // false_grad = switch(grad[0], op.inputs[1])[0] | |||
| // true_grad = switch(grad[1], op.inputs[1])[1] | |||
| // return merge([false_grad, true_grad])[0], None | |||
| } | |||
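For orientation, a rough C# sketch of just the CondContext branch of the commented Python above follows. It is not part of this change: the public `branch` accessor on CondContext, the placement of `merge` in gen_control_flow_ops, and the omission of the loop-switch and resource-variable cases are all assumptions made for illustration.

```csharp
// Hypothetical sketch only (not in this commit): the CondContext branch of
// _SwitchGrad, mirroring the Python reference above. Assumes CondContext
// exposes `branch` publicly and that gen_control_flow_ops.merge is the
// (output, value_index) helper shown at the end of this diff.
public Tensor[] _SwitchGradCondSketch(Operation op, Tensor[] grads)
{
    var op_ctxt = op._get_control_flow_context();
    if (op_ctxt is CondContext cond_ctxt)
    {
        // Gradient arriving from the branch that was not taken; may be null
        // for non-trainable dtypes.
        var zero_grad = grads[1 - cond_ctxt.branch];
        if (zero_grad == null)
            return new Tensor[] { null, null };

        // Merge the branch gradients; the merged tensor is the gradient for
        // the Switch data input, and the predicate gets no gradient.
        var merged = gen_control_flow_ops.merge(grads, name: "cond_grad");
        return new Tensor[] { merged.Item1, null };
    }

    // The while-loop (two-visit) case described in the summary is not sketched.
    throw new NotImplementedException("_SwitchGradCondSketch: loop switch");
}
```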
| /// <summary> | |||
| /// Gradients for a Merge op are calculated using a Switch op. | |||
| /// </summary> | |||
| public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) | |||
| { | |||
| var grad = grads[0]; | |||
| @@ -14,10 +81,164 @@ namespace Tensorflow.Gradients | |||
| var input_op = op.inputs[0].op; | |||
| var graph = ops.get_default_graph(); | |||
| var op_ctxt = control_flow_util.GetOutputContext(input_op); | |||
| var pred = (op_ctxt as CondContext).pred; | |||
| var grad_ctxt = graph._get_control_flow_context(); | |||
| switch (op_ctxt) | |||
| { | |||
| case WhileContext cwhile: | |||
| { | |||
| return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot); | |||
| } | |||
| case CondContext ccond: | |||
| { | |||
| var pred = ccond.pred; | |||
| if (grad_ctxt != null && grad_ctxt.grad_state != null) | |||
| { | |||
| //# This Merge node is part of a cond within a loop. | |||
| //# The backprop needs to have the value of this predicate for every | |||
| //# iteration. So we must have its values accumulated in the forward, and | |||
| //# use the accumulated values as the predicate for this backprop switch. | |||
| var grad_state = grad_ctxt.grad_state; | |||
| var real_pred = grad_state.history_map[pred.name] as Tensor; | |||
| if (real_pred == null) | |||
| { | |||
| //# Remember the value of pred for every iteration. | |||
| grad_ctxt = grad_state.grad_context; | |||
| grad_ctxt.Exit(); | |||
| var history_pred = grad_state.AddForwardAccumulator(pred); | |||
| grad_ctxt.Enter(); | |||
| //# Add the stack pop op. If pred.op is in a (outer) CondContext, | |||
| //# the stack pop will be guarded with a switch. | |||
| real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred); | |||
| grad_state.history_map[pred.name] = real_pred; | |||
| } | |||
| pred = real_pred; | |||
| } | |||
| var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); | |||
| return results; | |||
| } | |||
| default: | |||
| { | |||
| var num_inputs = op.inputs.Length; | |||
| var cond = new Tensor[num_inputs]; | |||
| for (int i = 0; i < num_inputs; i++) | |||
| cond[i] = math_ops.equal(op.outputs[1], i); | |||
| var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray(); | |||
| return result; | |||
| } | |||
| } | |||
| var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); | |||
| return new Tensor[] { results.Item1, results.Item2 }; | |||
| } | |||
| } | |||
| public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) | |||
| { | |||
| return _MergeGrad(op, grads); | |||
| } | |||
| /// <summary> | |||
| /// Gradients for an exit op are calculated using an Enter op. | |||
| /// </summary> | |||
| public Tensor[] _ExitGrad(Operation op, Tensor[] grads) | |||
| { | |||
| throw new NotImplementedException("_ExitGrad"); | |||
| // graph = ops.get_default_graph() | |||
| //# pylint: disable=protected-access | |||
| // op_ctxt = op._get_control_flow_context() | |||
| // grad_ctxt = graph._get_control_flow_context() | |||
| // # pylint: enable=protected-access | |||
| // if not grad_ctxt.back_prop: | |||
| // # The flag `back_prop` is set by users to suppress gradient | |||
| // # computation for this loop. If the attribute `back_prop` is false, | |||
| // # no gradient computation. | |||
| // return None | |||
| // if op_ctxt.grad_state: | |||
| // raise TypeError("Second-order gradient for while loops not supported.") | |||
| // if isinstance(grad, ops.Tensor) : | |||
| // grad_ctxt.AddName(grad.name) | |||
| // else: | |||
| // if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)): | |||
| // raise TypeError("Type %s not supported" % type(grad)) | |||
| // grad_ctxt.AddName(grad.values.name) | |||
| // grad_ctxt.AddName(grad.indices.name) | |||
| // dense_shape = grad.dense_shape | |||
| // if dense_shape is not None: | |||
| // grad_ctxt.AddName(dense_shape.name) | |||
| // grad_ctxt.Enter() | |||
| // # pylint: disable=protected-access | |||
| // result = control_flow_ops._Enter( | |||
| // grad, grad_ctxt.name, is_constant=False, | |||
| // parallel_iterations=grad_ctxt.parallel_iterations, | |||
| // name="b_exit") | |||
| // # pylint: enable=protected-access | |||
| // grad_ctxt.loop_enters.append(result) | |||
| // grad_ctxt.Exit() | |||
| // return result | |||
| } | |||
| /// <summary> | |||
| /// A forward next_iteration is translated into a backprop identity. | |||
| /// | |||
| /// Note that the backprop next_iteration is added in switch grad. | |||
| /// </summary> | |||
| public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad) | |||
| { | |||
| return (_, grad); | |||
| } | |||
| public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad) | |||
| { | |||
| return (_, grad); | |||
| } | |||
| /// <summary> | |||
| /// Gradients for an Enter are calculated using an Exit op. | |||
| /// | |||
| /// For loop variables, grad is the gradient so just add an exit. | |||
| /// For loop invariants, we need to add an accumulator loop. | |||
| /// </summary> | |||
| public (object, Tensor[]) _EnterGrad(Tensor op, Tensor[] grad) | |||
| { | |||
| throw new NotImplementedException("_EnterGrad"); | |||
| // graph = ops.get_default_graph() | |||
| //# pylint: disable=protected-access | |||
| // grad_ctxt = graph._get_control_flow_context() | |||
| // # pylint: enable=protected-access | |||
| // if not grad_ctxt.back_prop: | |||
| // # Skip gradient computation, if the attribute `back_prop` is false. | |||
| // return grad | |||
| // if grad_ctxt.grad_state is None: | |||
| // # Pass the gradient through if we are not in a gradient while context. | |||
| // return grad | |||
| // if op.get_attr("is_constant"): | |||
| // # Add a gradient accumulator for each loop invariant. | |||
| // if isinstance(grad, ops.Tensor) : | |||
| // result = grad_ctxt.AddBackpropAccumulator(op, grad) | |||
| // elif isinstance(grad, ops.IndexedSlices) : | |||
| // result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad) | |||
| // else: | |||
| // # TODO(yuanbyu, lukasr): Add support for SparseTensor. | |||
| // raise TypeError("Type %s not supported" % type(grad)) | |||
| // else: | |||
| // result = exit(grad) | |||
| // grad_ctxt.loop_exits.append(result) | |||
| // grad_ctxt.ExitResult([result]) | |||
| // return result | |||
| } | |||
| public (object, Tensor[]) _RefEnterGrad(Tensor op, Tensor[] grad) | |||
| { | |||
| return _EnterGrad(op, grad); | |||
| } | |||
| /// <summary> | |||
| /// Stop backprop for the predicate of a while loop. | |||
| /// </summary> | |||
| public object _LoopCondGrad(object _) | |||
| { | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| @@ -3,13 +3,14 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Graph | |||
| { | |||
| // Current control flow context. It could be either CondContext or WhileContext | |||
| public IControlFlowContext _control_flow_context; | |||
| public ControlFlowContext _control_flow_context; | |||
| // represents the nested with(...) statements | |||
| public List<_ControlDependenciesController> _control_dependencies_stack { get; set; } = new List<_ControlDependenciesController>(); | |||
| @@ -97,7 +98,7 @@ namespace Tensorflow | |||
| /// Returns the current control flow context. | |||
| /// </summary> | |||
| /// <returns>A context object.</returns> | |||
| public IControlFlowContext _get_control_flow_context() | |||
| public ControlFlowContext _get_control_flow_context() | |||
| { | |||
| return _control_flow_context; | |||
| } | |||
| @@ -106,7 +107,7 @@ namespace Tensorflow | |||
| /// Sets the current control flow context. | |||
| /// </summary> | |||
| /// <param name="ctx">a context object.</param> | |||
| public void _set_control_flow_context(IControlFlowContext ctx) | |||
| public void _set_control_flow_context(ControlFlowContext ctx) | |||
| { | |||
| _control_flow_context = ctx; | |||
| } | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -15,7 +16,7 @@ namespace Tensorflow | |||
| private List<ITensorOrOperation> _seen_nodes; | |||
| private List<_ControlDependenciesController> _old_stack; | |||
| private bool _new_stack; | |||
| private IControlFlowContext _old_control_flow_context; | |||
| private ControlFlowContext _old_control_flow_context; | |||
| public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Operations.ControlFlows; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| @@ -107,8 +108,8 @@ namespace Tensorflow.Operations | |||
| with(ops.control_dependencies(null), ctrl => | |||
| { | |||
| var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred); | |||
| result = new[] { r0, r1 }[_branch]; | |||
| var results = control_flow_ops._SwitchRefOrTensor(result, _pred); | |||
| result = results[_branch]; | |||
| if (_outer_context != null) | |||
| _outer_context.AddInnerOp(result.op); | |||
| }); | |||
| @@ -118,7 +119,7 @@ namespace Tensorflow.Operations | |||
| // Mark Switch output as seen by this context and any outer contexts, | |||
| // just like what we do for normal op outputs in _AddOpInternal() below. | |||
| IControlFlowContext ctxt = this; | |||
| ControlFlowContext ctxt = this; | |||
| while (ctxt != null) | |||
| { | |||
| ctxt.values.Add(result.name); | |||
| @@ -223,8 +224,8 @@ namespace Tensorflow.Operations | |||
| _values.Add(real_val.name); | |||
| _external_values[real_val.name] = real_val; | |||
| } | |||
| var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred); | |||
| real_val = new[] {t0, t1}[_branch]; | |||
| var results = control_flow_ops._SwitchRefOrTensor(real_val, _pred); | |||
| real_val = results[_branch]; | |||
| _external_values[val.name] = real_val; | |||
| } | |||
| else | |||
| @@ -238,8 +239,8 @@ namespace Tensorflow.Operations | |||
| return real_val; | |||
| } | |||
| protected override void _AddOpInternal(Operation op) | |||
| { | |||
| protected override void _AddOpInternal(Operation op) | |||
| { | |||
| if (op.inputs.Length == 0) | |||
| { | |||
| //If we're in a while loop, remove any control inputs from outside the | |||
| @@ -282,11 +283,11 @@ namespace Tensorflow.Operations | |||
| // TODO: implement below code dependencies | |||
| //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") | |||
| // op._add_control_input(_pivot.op); | |||
| } | |||
| // Mark op's outputs as seen by this context and any outer contexts. | |||
| } | |||
| // Mark op's outputs as seen by this context and any outer contexts. | |||
| var output_names = op.outputs.Select(x => x.name).ToArray(); | |||
| IControlFlowContext ctxt = this; | |||
| ControlFlowContext ctxt = this; | |||
| while (ctxt != null) | |||
| { | |||
| foreach (var name in output_names) | |||
| @@ -298,9 +299,31 @@ namespace Tensorflow.Operations | |||
| op.graph.prevent_fetching(op); | |||
| if (_outer_context != null) | |||
| _outer_context.AddInnerOp(op); | |||
| } | |||
| _outer_context.AddInnerOp(op); | |||
| } | |||
| public override GradLoopState grad_state | |||
| { | |||
| get | |||
| { | |||
| var whc = GetWhileContext(); | |||
| if (whc != null) | |||
| return whc.grad_state; | |||
| return null; | |||
| } | |||
| } | |||
| public override bool back_prop | |||
| { | |||
| get | |||
| { | |||
| var whc = GetWhileContext(); | |||
| if (whc != null) | |||
| return whc.back_prop; | |||
| return false; | |||
| } | |||
| } | |||
| public CondContextDef to_proto(string export_scope) | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Operations.ControlFlows; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| @@ -22,21 +23,25 @@ namespace Tensorflow.Operations | |||
| /// 4. A ControlFlowContext has _context_stack. | |||
| /// Pushed and popped by ctxt.Enter() and ctxt.Exit() | |||
| /// </summary> | |||
| public abstract class ControlFlowContext : Python, IPython, IControlFlowContext | |||
| public abstract class ControlFlowContext : Python, IPython | |||
| { | |||
| /// <summary> | |||
| /// The predicate tensor in this branch | |||
| /// </summary> | |||
| protected Tensor _pivot; | |||
| public Tensor pivot | |||
| { | |||
| get => _pivot; | |||
| } | |||
| protected Stack<IControlFlowContext> _context_stack; | |||
| protected IControlFlowContext _outer_context; | |||
| protected Stack<ControlFlowContext> _context_stack; | |||
| protected ControlFlowContext _outer_context; | |||
| protected Dictionary<string, ITensorOrOperation> _external_values; | |||
| public ControlFlowContext() | |||
| { | |||
| _context_stack = new Stack<IControlFlowContext>(); | |||
| _context_stack = new Stack<ControlFlowContext>(); | |||
| } | |||
| public string name { get => _name; } | |||
| @@ -111,8 +116,13 @@ namespace Tensorflow.Operations | |||
| _AddOpInternal(op); | |||
| } | |||
| public IControlFlowContext outer_context { get { return _outer_context; } } | |||
| public ControlFlowContext outer_context { get { return _outer_context; } } | |||
| public HashSet<string> values => _values; | |||
| public virtual GradLoopState grad_state => throw new NotImplementedException("abstract method"); | |||
| public virtual bool back_prop => throw new NotImplementedException("abstract method"); | |||
| public virtual Tensor AddValue(Tensor val) | |||
| { | |||
| // to be overridden | |||
| @@ -147,7 +157,7 @@ namespace Tensorflow.Operations | |||
| /// <summary> | |||
| /// Returns true if `maybe_containing_ctxt` is or contains `ctxt`. | |||
| /// </summary> | |||
| public static bool IsContainingContext(IControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt) | |||
| public static bool IsContainingContext(ControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt) | |||
| { | |||
| while (ctxt != maybe_containing_ctxt) | |||
| { | |||
| @@ -164,6 +174,16 @@ namespace Tensorflow.Operations | |||
| var internal_control_inputs = op.control_inputs; | |||
| } | |||
| /// <summary> | |||
| /// Returns the while context that contains this context, or null if this context is not nested in a while loop. | |||
| /// </summary> | |||
| public virtual WhileContext GetWhileContext() | |||
| { | |||
| if (_outer_context != null) | |||
| return _outer_context.GetWhileContext(); | |||
| return null; | |||
| } | |||
| public object to_proto() | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -173,5 +193,6 @@ namespace Tensorflow.Operations | |||
| public void Dispose() | |||
| { | |||
| } | |||
| } | |||
| } | |||
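The virtual `grad_state`/`back_prop` members and `GetWhileContext()` added above follow a simple delegation pattern: the base class walks outward through `_outer_context`, and `WhileContext` overrides the walk to return itself. The self-contained toy below illustrates the effect with simplified stand-in types (not the library's real classes): a cond context nested in a while loop picks up the loop's properties without any cond-specific code.

```csharp
using System;

// Simplified stand-ins for ControlFlowContext / WhileContext / CondContext.
abstract class Ctx
{
    public Ctx Outer;

    // Default behaviour: walk outward until a while context (or null) is found.
    public virtual Ctx GetWhileContext() => Outer?.GetWhileContext();

    // Default behaviour: ask the enclosing while loop, if any.
    public virtual bool back_prop => GetWhileContext()?.back_prop ?? false;
}

class WhileCtx : Ctx
{
    public override Ctx GetWhileContext() => this;
    public override bool back_prop => true;
}

class CondCtx : Ctx { }

static class ContextDelegationDemo
{
    static void Main()
    {
        var loop = new WhileCtx();
        var branch = new CondCtx { Outer = loop };

        Console.WriteLine(branch.GetWhileContext() == loop); // True
        Console.WriteLine(branch.back_prop);                 // True

        var lone = new CondCtx();
        Console.WriteLine(lone.back_prop);                   // False
    }
}
```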
| @@ -0,0 +1,277 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations.ControlFlows | |||
| { | |||
| /// <summary> | |||
| /// Maintain the mapping from the loops to their grad states. | |||
| /// </summary> | |||
| public class ControlFlowState | |||
| { | |||
| //class ControlFlowState(object): | |||
| // """Maintain the mapping from the loops to their grad states.""" | |||
| // def __init__(self): | |||
| // self._map = {} # maps forward loop context to GradLoopState | |||
| // def GetGradState(self, op, before): | |||
| // """Return the grad state for this op if it's in a forward loop context.""" | |||
| // if before and util.IsLoopExit(op): | |||
| // forward_ctxt = op._get_control_flow_context() | |||
| // forward_ctxt = forward_ctxt.outer_context | |||
| // if forward_ctxt: | |||
| // forward_ctxt = forward_ctxt.GetWhileContext() | |||
| // else: | |||
| // forward_ctxt = _GetWhileContext(op) | |||
| // if forward_ctxt: | |||
| // return self._map.get(forward_ctxt) | |||
| // return None | |||
| // def ProcessUnusedLoopExits(self, pending_count, to_ops_set): | |||
| // """Process all the "unused" loop exits. | |||
| // The "unused" exits of the loops are added to `unused_exits`. An exit is | |||
| // unused if its pending_count is 0. If there is an exit with real gradient, | |||
| // all these deferred exits will enter the backprop loop with zero gradient. | |||
| // Otherwise, they will enter the backprop loop with None. As an example, | |||
| // people often write: | |||
| // ```python | |||
| // v1, _ = tf.while_loop(p, b, [x1, x2]) | |||
| // result = gradients(v1, x1) | |||
| // ``` | |||
| // The exit node for x2 is not included by the betweenness analysis. But we | |||
| // need to backprop x2 if x2 is involved in computing v1. | |||
| // Args: | |||
| // pending_count: The number of backprop inputs for every op. | |||
| // to_ops_set: The set of ops for ys in gradients(ys, xs) | |||
| // Returns: | |||
| // The set of unused loop exits that we know at this point we need | |||
| // to backprop. | |||
| // """ | |||
| // loop_exits = [] | |||
| // for grad_state in self._map.values(): | |||
| // for y in grad_state.forward_loop_exits: | |||
| // if pending_count[y.op] == 0: | |||
| // grad_state.pending_exits_count -= 1 | |||
| // if y.op not in to_ops_set: | |||
| // grad_state.unused_exits.append(y) | |||
| // if grad_state.pending_exits_count == 0: | |||
| // loop_exits.extend(grad_state.unused_exits) | |||
| // # Need to include Enters in backprop for higher-order gradients. | |||
| // for y in grad_state.forward_context.loop_enters: | |||
| // if pending_count[y.op] == 0: | |||
| // pending_count[y.op] = 1 | |||
| // return loop_exits | |||
| // def EnterGradWhileContext(self, op, before): | |||
| // """Enter the WhileContext for gradient computation.""" | |||
| // grad_state = self.GetGradState(op, before) | |||
| // if grad_state: | |||
| // grad_state.grad_context.Enter() | |||
| // def ExitGradWhileContext(self, op, before): | |||
| // """Exit the WhileContext for gradient computation.""" | |||
| // grad_state = self.GetGradState(op, before) | |||
| // if grad_state: | |||
| // grad_state.grad_context.Exit() | |||
| // def AddWhileContext(self, op, between_op_list, between_ops): | |||
| // """Add the grad state for the while loop that op belongs to. | |||
| // Note that op is an Exit, and this method must be called in | |||
| // the control flow context where gradients() is called. | |||
| // Note that this method modifies `between_op_list` and `between_ops`. | |||
| // """ | |||
| // forward_ctxt = _GetWhileContext(op) | |||
| // grad_state = self._map.get(forward_ctxt) | |||
| // if grad_state is None: | |||
| // # This is a new while loop so create a grad state for it. | |||
| // outer_forward_ctxt = forward_ctxt.outer_context | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() | |||
| // outer_grad_state = None | |||
| // if outer_forward_ctxt: | |||
| // outer_grad_state = self._map.get(outer_forward_ctxt) | |||
| // grad_state = GradLoopState(forward_ctxt, outer_grad_state) | |||
| // self._map[forward_ctxt] = grad_state | |||
| // # We need to include all exits of a loop for backprop. | |||
| // for loop_exit in grad_state.forward_loop_exits: | |||
| // if loop_exit.op not in between_ops: | |||
| // between_ops.add(loop_exit.op) | |||
| // between_op_list.append(loop_exit.op) | |||
| // def ZerosLikeForExit(self, val): | |||
| // """Create zeros_like gradient for a loop exit. | |||
| // If the result of a loop variable is not used but is involved in | |||
| // computing the result of some needed loop variable, we create a | |||
| // zero-valued tensor that is fed as gradient for the Exit node of that | |||
| // loop variable. Note that val.op is an Exit, and this method must be | |||
| // called in the control flow context where gradients() is called. | |||
| // Args: | |||
| // val: The output tensor of an Exit op. | |||
| // Returns: | |||
| // A zero tensor of the same shape of val. | |||
| // """ | |||
| // val_shape = val.get_shape() | |||
| // forward_ctxt = val.op._get_control_flow_context() | |||
| // outer_forward_ctxt = forward_ctxt.outer_context | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() | |||
| // outer_grad_state = None | |||
| // if outer_forward_ctxt: | |||
| // outer_grad_state = self._map.get(outer_forward_ctxt) | |||
| // if outer_grad_state: | |||
| // # This is a nested loop. | |||
| // if val_shape.is_fully_defined(): | |||
| // # If the shape is known statically, just create a zero tensor | |||
| // # with the right shape in the right context. | |||
| // outer_grad_state.grad_context.Enter() | |||
| // result = array_ops.zeros(val_shape.dims, val.dtype) | |||
| // outer_grad_state.grad_context.Exit() | |||
| // else: | |||
| // # Only the shape of value is needed for backprop. | |||
| // forward_ctxt.outer_context.Enter() | |||
| // shape = array_ops.shape_internal(val, optimize=False) | |||
| // forward_ctxt.outer_context.Exit() | |||
| // # Save the shape to a stack. | |||
| // history_shape = outer_grad_state.AddForwardAccumulator(shape) | |||
| // # Get the shape back from the stack. | |||
| // outer_grad_ctxt = outer_grad_state.grad_context | |||
| // outer_grad_ctxt.Enter() | |||
| // real_shape = outer_grad_state.AddBackpropAccumulatedValue( | |||
| // history_shape, shape) | |||
| // result = array_ops.zeros(real_shape, val.dtype) | |||
| // outer_grad_ctxt.Exit() | |||
| // else: | |||
| // # This is not a nested loop. | |||
| // if val_shape.is_fully_defined(): | |||
| // # If the shape is known statically, just create a zero tensor | |||
| // # with the right shape. | |||
| // result = array_ops.zeros(val_shape.dims, val.dtype) | |||
| // else: | |||
| // result = array_ops.zeros_like(val, optimize=False) | |||
| // return result | |||
| // def ZerosLike(self, op, index): | |||
| // """Create zeros_like for the specified output of an op. | |||
| // If op is in a while loop that is part of gradients(), this method | |||
| // must be called in its grad loop context. | |||
| // Args: | |||
| // op: A tensorflow operation. | |||
| // index: the index for a specific output of the op. | |||
| // Returns: | |||
| // A zero tensor of the same shape of op.outputs[index]. | |||
| // """ | |||
| // if util.IsLoopSwitch(op): | |||
| // return None | |||
| // if op.graph._building_function: # pylint: disable=protected-access | |||
| // # The optimization here is tricky to apply to functions | |||
| // return array_ops.zeros_like(op.outputs[index]) | |||
| // dead_branch = util.IsSwitch(op) | |||
| // forward_ctxt = _GetWhileContext(op) | |||
| // grad_state = self._map.get(forward_ctxt) | |||
| // if grad_state is None: | |||
| // # op is not in a while loop that is part of gradients(). | |||
| // return ZerosLikeOutsideLoop(op, index) | |||
| // op_ctxt = op._get_control_flow_context() | |||
| // val = ops.convert_to_tensor(op.outputs[index], name="tensor") | |||
| // shape = val.get_shape() | |||
| // if shape.is_fully_defined(): | |||
| // # If the shape is known statically, just create a zero tensor with | |||
| // # the right shape in the grad loop context. | |||
| // result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype) | |||
| // if dead_branch: | |||
| // # op is a cond switch. Guard the zero tensor with a switch. | |||
| // pred = grad_state.history_map.get(op_ctxt.pred.name) | |||
| // branch = op_ctxt.branch | |||
| // result = _SwitchRefOrTensor(result, pred)[1 - branch] | |||
| // else: | |||
| // # Unknown shape so keep a history of the shape at runtime. | |||
| // if dead_branch: | |||
| // # Need to add a special switch to guard the value. | |||
| // pred = op_ctxt.pred | |||
| // branch = op_ctxt.branch | |||
| // op_ctxt.outer_context.Enter() | |||
| // val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch] | |||
| // zeros_shape = array_ops.shape_internal(val, optimize=False) | |||
| // op_ctxt.outer_context.Exit() | |||
| // val.op._set_control_flow_context(op_ctxt) | |||
| // zeros_shape.op._set_control_flow_context(op_ctxt) | |||
| // else: | |||
| // op_ctxt.Enter() | |||
| // zeros_shape = array_ops.shape_internal(val, optimize=False) | |||
| // op_ctxt.Exit() | |||
| // # Add forward accumulator for shape. | |||
| // grad_state.grad_context.Exit() | |||
| // history_zeros_shape = grad_state.AddForwardAccumulator( | |||
| // zeros_shape, dead_branch=dead_branch) | |||
| // grad_state.grad_context.Enter() | |||
| // # Create a zero tensor with the right shape. | |||
| // shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape, | |||
| // zeros_shape, dead_branch) | |||
| // result = array_ops.zeros(shape, val.dtype) | |||
| // return result | |||
| // def PostProcessing(self): | |||
| // """Perform postprocessing at the end of gradients(). | |||
| // We have created the gradient graph at this point. So this function | |||
| // can be used to perform any postprocessing on the gradient graph. | |||
| // We currently perform the following postprocessing: | |||
| // 1. Patch the gradient graph if the output of a loop variable | |||
| // doesn't depend on its input. | |||
| // """ | |||
| // for _, grad_state in self._map.items(): | |||
| // for _, b_merge in grad_state.switch_map.items(): | |||
| // if b_merge.op.inputs[0] == b_merge.op.inputs[1]: | |||
| // # The value of this loop variable at iteration i+1 doesn't | |||
| // # depend on its value at iteration i. So use zeros as the | |||
| // # gradients for all iterations > 0. | |||
| // dtype = b_merge.op.inputs[0].dtype | |||
| // shape = b_merge.op.inputs[0].get_shape() | |||
| // # pylint: disable=protected-access | |||
| // if shape.is_fully_defined(): | |||
| // grad_state.grad_context.Enter() | |||
| // # Create a zeros and use it for iterations > 0. | |||
| // grad_val = constant_op.constant(0, dtype=dtype, shape=shape) | |||
| // next_grad_val = _NextIteration(grad_val) | |||
| // grad_state.grad_context.Exit() | |||
| // else: | |||
| // # Create a zeros in the outer grad context. | |||
| // outer_grad_ctxt = grad_state.grad_context.outer_context | |||
| // if outer_grad_ctxt: | |||
| // outer_grad_ctxt.Enter() | |||
| // enter_grad_op = b_merge.op.inputs[0].op | |||
| // enter_grad = enter_grad_op.inputs[0] | |||
| // grad_shape = array_ops.shape_internal(enter_grad, optimize=False) | |||
| // grad_val = array_ops.zeros(grad_shape) | |||
| // if outer_grad_ctxt: | |||
| // outer_grad_ctxt.Exit() | |||
| // # Use the zeros for iterations > 0. | |||
| // grad_state.grad_context.Enter() | |||
| // next_grad_val = _NextIteration(grad_val) | |||
| // grad_state.grad_context.Exit() | |||
| // b_merge.op._update_input(1, next_grad_val) | |||
| // # pylint: enable=protected-access | |||
| } | |||
| } | |||
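The commented Python above centres on a single piece of state: a map from each forward-pass WhileContext to the GradLoopState built for it. A hypothetical sketch of how that could look once ported is given below; the field and method are illustrations only, not code introduced by this change.

```csharp
using System.Collections.Generic;
using Tensorflow.Operations;
using Tensorflow.Operations.ControlFlows;

// Hypothetical sketch, not part of this commit.
public class ControlFlowStateSketch
{
    // Maps a forward while-loop context to the grad state created for it,
    // mirroring `self._map` in the Python reference.
    private readonly Dictionary<WhileContext, GradLoopState> _map
        = new Dictionary<WhileContext, GradLoopState>();

    // Returns the grad state for the given forward while context, or null if
    // that loop does not participate in this gradients() call.
    public GradLoopState GetGradState(WhileContext forward_ctxt)
    {
        if (forward_ctxt != null && _map.TryGetValue(forward_ctxt, out var state))
            return state;
        return null;
    }
}
```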
| @@ -0,0 +1,398 @@ | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations.ControlFlows | |||
| { | |||
| public class GradLoopState | |||
| { | |||
| //class GradLoopState(object): | |||
| // """The state used for constructing the gradient graph for a while loop. | |||
| // We create a GradLoopState for each while loop in forward and its | |||
| // corresponding while loop in backprop. This gives us access to both | |||
| // the forward and the backprop WhileContexts. | |||
| // During the construction of gradient graph, any time when we detect | |||
| // a forward value that is needed for backprop, we create a history | |||
| // accumulator and add it to `history_map`. Any time when we backprop | |||
| // a loop switch op (in _SwitchGrad), we add the grad merge op in | |||
| // `switch_map`. | |||
| // """ | |||
| // def __init__(self, forward_ctxt, outer_grad_state): | |||
| // # The grad loop state for the outer while loop. | |||
| // self._outer_grad_state = None | |||
| // # The while loop context for forward. | |||
| // self._forward_context = None | |||
| // # The loop counter added by AddForwardLoopCounter. It is the value | |||
| // # of the loop counter for the next iteration. | |||
| // self._forward_index = None | |||
| // # A sync op for forward. | |||
| // self._forward_sync = None | |||
| // # The while loop context for backprop. | |||
| private WhileContext _grad_context = null; | |||
| public WhileContext grad_context => _grad_context; | |||
| // # The loop counter added by AddBackpropLoopCounter. It is the value | |||
| // # of the loop counter for the current iteration. | |||
| // self._grad_index = None | |||
| // # A sync op for backprop. | |||
| // self._grad_sync = None | |||
| // # Information needed by backprop. | |||
| private Hashtable _history_map = new Hashtable(); | |||
| public Hashtable history_map => _history_map; | |||
| private Hashtable _switch_map = new Hashtable(); | |||
| public Hashtable switch_map => _switch_map; | |||
| // self._unused_exits = [] | |||
| // self._deferred_exits = [] | |||
| // self._forward_loop_exits = list(forward_ctxt.loop_exits) | |||
| // self._pending_exits_count = len(forward_ctxt.loop_exits) | |||
| // self._outer_grad_state = outer_grad_state | |||
| // if outer_grad_state: | |||
| // outer_forward_ctxt = outer_grad_state.forward_context | |||
| // else: | |||
| // if not hasattr(forward_ctxt, "outer_context"): | |||
| // raise ValueError("Failed to call gradients on a while loop without" | |||
| // "properly serializing graph via MetaGraphDef") | |||
| // outer_forward_ctxt = forward_ctxt.outer_context | |||
| // # Add the forward loop counter. | |||
| // with forward_ctxt._graph.as_default(): # pylint: disable=protected-access | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Enter() | |||
| // cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Exit() | |||
| // self._forward_context = forward_ctxt | |||
| // self._forward_index = forward_index | |||
| // # Add the backprop WhileContext, and the backprop loop counter. | |||
| // if outer_grad_state: | |||
| // # This is a nested loop. Remember the iteration counts for each | |||
| // # execution of this inner loop. | |||
| // outer_forward_ctxt.AddName(cnt.name) | |||
| // history_cnt = outer_grad_state.AddForwardAccumulator(cnt) | |||
| // outer_grad_ctxt = outer_grad_state.grad_context | |||
| // outer_grad_ctxt.Enter() | |||
| // self._grad_context = WhileContext( | |||
| // maximum_iterations=forward_ctxt.maximum_iterations, | |||
| // parallel_iterations=forward_ctxt.parallel_iterations, | |||
| // back_prop=forward_ctxt.back_prop, | |||
| // swap_memory=forward_ctxt.swap_memory, | |||
| // name=forward_ctxt.name, | |||
| // grad_state=self) | |||
| // real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt) | |||
| // self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||
| // real_cnt, outer_grad_state) | |||
| // outer_grad_ctxt.Exit() | |||
| // else: | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Enter() | |||
| // self._grad_context = WhileContext( | |||
| // maximum_iterations=forward_ctxt.maximum_iterations, | |||
| // parallel_iterations=forward_ctxt.parallel_iterations, | |||
| // back_prop=forward_ctxt.back_prop, | |||
| // swap_memory=forward_ctxt.swap_memory, | |||
| // name=forward_ctxt.name, | |||
| // grad_state=self) | |||
| // self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||
| // cnt, outer_grad_state) | |||
| // if outer_forward_ctxt: | |||
| // outer_forward_ctxt.Exit() | |||
| // @property | |||
| // def outer_grad_state(self): | |||
| // """The grad loop state for outer loop.""" | |||
| // return self._outer_grad_state | |||
| // @property | |||
| // def forward_context(self): | |||
| // """The while loop context for forward.""" | |||
| // return self._forward_context | |||
| // @property | |||
| // def forward_index(self): | |||
| // """The loop index of forward loop.""" | |||
| // return self._forward_index | |||
| // @property | |||
| // def forward_sync(self): | |||
| // """A control trigger node for synchronization in the forward loop. | |||
| // One main use is to keep the push ops of a stack executed in the | |||
| // iteration order. | |||
| // """ | |||
| // if self._forward_sync is None: | |||
| // with ops.control_dependencies(None): | |||
| // self._forward_sync = control_trigger(name="f_sync") | |||
| // self._forward_sync._set_control_flow_context(self._forward_context) | |||
| // self._forward_index.op._add_control_input(self._forward_sync) | |||
| // return self._forward_sync | |||
| // @property | |||
| // def grad_context(self): | |||
| // """The corresponding WhileContext for gradient.""" | |||
| // return self._grad_context | |||
| // @property | |||
| // def grad_index(self): | |||
| // """The loop index of backprop loop.""" | |||
| // return self._grad_index | |||
| // @property | |||
| // def grad_sync(self): | |||
| // """A control trigger node for synchronization in the grad loop. | |||
| // One main use is to keep the pop ops of a stack executed in the | |||
| // iteration order. | |||
| // """ | |||
| // if self._grad_sync is None: | |||
| // with ops.control_dependencies(None): | |||
| // self._grad_sync = control_trigger(name="b_sync") | |||
| // self._grad_sync._set_control_flow_context(self._grad_context) | |||
| // self._grad_index.op._add_control_input(self._grad_sync) | |||
| // if self._grad_context.outer_context: | |||
| // self._grad_context.outer_context.AddInnerOp(self._grad_sync) | |||
| // return self._grad_sync | |||
| // @property | |||
| // def history_map(self): | |||
| // """The map that records all the tensors needed for backprop.""" | |||
| // return self._history_map | |||
| // @property | |||
| // def switch_map(self): | |||
| // """The map that records all the Switch ops for the while loop.""" | |||
| // return self._switch_map | |||
| // @property | |||
| // def unused_exits(self): | |||
| // """The list of "unused" exits.""" | |||
| // return self._unused_exits | |||
| // @property | |||
| // def deferred_exits(self): | |||
| // """The list of "deferred" exits.""" | |||
| // return self._deferred_exits | |||
| // @property | |||
| // def forward_loop_exits(self): | |||
| // """The list of exits of the forward loop.""" | |||
| // return self._forward_loop_exits | |||
| // @property | |||
| // def pending_exits_count(self): | |||
| // """The number of exits we expect to see but haven't.""" | |||
| // return self._pending_exits_count | |||
| // @pending_exits_count.setter | |||
| // def pending_exits_count(self, cnt): | |||
| // """Set the pending count to cnt.""" | |||
| // self._pending_exits_count = cnt | |||
| /// <summary> | |||
| /// Add an accumulator for each forward tensor that is needed in backprop. | |||
| /// | |||
| /// This is added to the forward loop at the first time when a tensor | |||
| /// in the forward loop is used by backprop gradient computation loop. | |||
| /// We create an accumulator that accumulates the value of tensor at each | |||
| /// iteration. Called in the control flow context where gradients() is called. | |||
| /// | |||
| /// The pseudocode is: | |||
| /// ``` | |||
| /// acc = stack(); | |||
| /// while (_pivot) { | |||
| /// acc = stack_push(acc, value); | |||
| /// } | |||
| /// ``` | |||
| /// | |||
| /// We make sure that the stack push op in one iteration is executed before | |||
| /// next iteration. This is achieved by adding a control edge from | |||
| /// `forward_index.op.inputs[0].op` to the push op, and another control | |||
| /// edge from the push op to either `forward_index.op` or `forward_sync`. | |||
| /// </summary> | |||
| /// <param name="value"> The source tensor in forward that is to be accumulated.</param> | |||
| /// <param name="dead_branch"> True iff the tensor is on a dead branch of a cond.</param> | |||
| /// <returns>The stack that contains the accumulated history of the tensor.</returns> | |||
| public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false) | |||
| { | |||
| throw new NotImplementedException("AddForwardAccumulator"); | |||
| // # curr_ctxt is the context that tf.gradients was called in. | |||
| // with self._forward_index.graph.as_default(): | |||
| // curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access | |||
| // with ops.control_dependencies(None): | |||
| // if curr_ctxt: | |||
| // curr_ctxt.Enter() | |||
| // with ops.colocate_with(value): | |||
| // # We only need to pass maximum_iterations to the stack if | |||
| // # we're inside an XLA context. | |||
| // if not util.IsInXLAContext(value.op): | |||
| // max_size = constant_op.constant(-1, dtypes.int32) | |||
| // else: | |||
| // max_size = GetMaxSizeFromNestedMaximumIterations( | |||
| // value, self.forward_context) | |||
| // acc = gen_data_flow_ops.stack_v2( | |||
| // max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") | |||
| // if curr_ctxt: | |||
| // curr_ctxt.Exit() | |||
| // # Make acc available in the forward context. | |||
| // enter_acc = self.forward_context.AddValue(acc) | |||
| // # Add the stack_push op in the context of value.op. | |||
| // swap_enabled = self.forward_context.swap_memory | |||
| // value_ctxt = util.GetOutputContext(value.op) | |||
| // if value_ctxt == self.forward_context: | |||
| // # value is not nested in the forward context. | |||
| // self.forward_context.Enter() | |||
| // push = gen_data_flow_ops.stack_push_v2( | |||
| // enter_acc, value, swap_memory=swap_enabled) | |||
| // self.forward_context.Exit() | |||
| // # Protect stack push and order it before forward_index. | |||
| // self.forward_index.op._add_control_input(push.op) | |||
| // else: | |||
| // # value is in a cond context within the forward context. | |||
| // if not isinstance(value_ctxt, CondContext): | |||
| // raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt) | |||
| // if dead_branch: | |||
| // # The special case for creating a zero tensor for a dead | |||
| // # branch of a switch. See ControlFlowState.ZerosLike(). | |||
| // value_ctxt.outer_context.Enter() | |||
| // push = gen_data_flow_ops.stack_push_v2( | |||
| // enter_acc, value, swap_memory=swap_enabled) | |||
| // value_ctxt.outer_context.Exit() | |||
| // push.op._set_control_flow_context(value_ctxt) | |||
| // else: | |||
| // value_ctxt.Enter() | |||
| // push = gen_data_flow_ops.stack_push_v2( | |||
| // enter_acc, value, swap_memory=swap_enabled) | |||
| // value_ctxt.Exit() | |||
| // # Protect stack push and order it before forward_sync. | |||
| // self.forward_sync._add_control_input(push.op) | |||
| // # Order stack push after the successor of forward_index | |||
| // add_op = self.forward_index.op.inputs[0].op | |||
| // push.op._add_control_input(add_op) | |||
| // return acc | |||
| } | |||
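The accumulate-in-forward / pop-in-backprop idea described in the summary above can be demonstrated without any TensorFlow machinery. The self-contained toy below uses a plain Stack&lt;double&gt; in place of the stack ops: the forward loop pushes the value it is about to transform, and the backprop loop pops those values in reverse iteration order to apply the chain rule; all names are illustrative.

```csharp
using System;
using System.Collections.Generic;

static class ForwardAccumulatorDemo
{
    static void Main()
    {
        var acc = new Stack<double>();   // plays the role of the "f_acc" stack
        double x = 1.5;

        // Forward loop: y = ((x^2)^2)^2. Save the input of every iteration,
        // because backprop needs it (d(x^2)/dx = 2x).
        for (int i = 0; i < 3; i++)
        {
            acc.Push(x);                 // stack_push in the forward context
            x = x * x;
        }

        // Backprop loop: pop the saved values in reverse iteration order,
        // which is what AddBackpropAccumulatedValue's stack_pop provides.
        double grad = 1.0;               // dy/dy
        while (acc.Count > 0)
        {
            double saved = acc.Pop();
            grad *= 2.0 * saved;         // chain rule through x -> x * x
        }

        Console.WriteLine(grad);         // equals 8 * 1.5^7, i.e. dy/dx
    }
}
```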
| // """Add the getter for an accumulated value in the grad context. | |||
| // | |||
| // This is added to the backprop loop. Called in the grad context to | |||
| // get the value of an accumulated value. The stack pop op must be guarded | |||
| // by the pred of the controlling cond. | |||
| // | |||
| // Args: | |||
| // history_value: The history (a stack) of a value. | |||
| // value: The value that is pushed onto the stack. | |||
| // dead_branch: True iff the tensor is on a dead branch of a cond. | |||
| // | |||
| // Returns: | |||
| // The current value (the top of the stack). | |||
| // """ | |||
| public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch = false) | |||
| { | |||
| throw new NotImplementedException(); | |||
| // history_ctxt = history_value.op._get_control_flow_context() | |||
| // # Find the cond context that controls history_value if any. | |||
| // cond_ctxt = None | |||
| // value_ctxt = value.op._get_control_flow_context() | |||
| // while value_ctxt and value_ctxt != history_ctxt: | |||
| // if isinstance(value_ctxt, CondContext): | |||
| // cond_ctxt = value_ctxt | |||
| // break | |||
| // value_ctxt = value_ctxt.outer_context | |||
| // with ops.control_dependencies(None): | |||
| // self.grad_context.Enter() | |||
| // if cond_ctxt: | |||
| // # Guard stack pop with a switch if it is controlled by a cond. | |||
| // grad_state = self | |||
| // pred = None | |||
| // while pred is None and grad_state: | |||
| // pred = grad_state.history_map.get(cond_ctxt.pred.name) | |||
| // grad_state = grad_state.outer_grad_state | |||
| // if pred is None: | |||
| // pred = cond_ctxt.pred | |||
| // branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch | |||
| // history_value = _SwitchRefOrTensor(history_value, pred)[branch] | |||
| // pop = gen_data_flow_ops.stack_pop_v2(history_value, | |||
| // value.dtype.base_dtype) | |||
| // pop.set_shape(value.get_shape()) | |||
| // self.grad_context.Exit() | |||
| // parallel_iterations = self.grad_context.parallel_iterations | |||
| // if parallel_iterations > 1: | |||
| // # All pops are ordered after pivot_for_body and before grad_sync. | |||
| // self.grad_sync._add_control_input(pop.op) | |||
| // return pop | |||
| } | |||
| // def GetRealValue(self, value): | |||
| // """Get the real value of `value`. | |||
| // If backprop "uses" a value produced by forward inference, an accumulator | |||
| // is added in the forward loop to accumulate its values. We use the | |||
| // accumulated value. This method must be called in the grad loop context. | |||
| // `value` must be in forward and needed for backprop. | |||
| // Args: | |||
| // value: A tensor to be captured. | |||
| // Returns: | |||
| // The same tensor obtained from the saved history. | |||
| // """ | |||
| // assert value.op.type not in ["Variable", "VariableV2"] | |||
| // real_value = self._history_map.get(value.name) | |||
| // if real_value is None: | |||
| // cur_value = value | |||
| // cur_grad_state = self | |||
| // while True: | |||
| // enter_op = util.GetLoopConstantEnter(cur_value) | |||
| // if enter_op: | |||
| // # Special case: cur_value comes from a constant Enter node. | |||
| // cur_value = enter_op.inputs[0] | |||
| // cur_grad_state = cur_grad_state.outer_grad_state | |||
| // if cur_grad_state is None: | |||
| // # We are now outside all nested loops for this gradient(), | |||
| // # so `value` is a loop invariant and there is no need to | |||
| // # save the history of value. Just make cur_value to enter | |||
| // # the right control flow context. | |||
| // real_value = self._grad_context.AddValue(cur_value) | |||
| // break | |||
| // elif constant_op.is_constant(cur_value): | |||
| // # If the value to be forwarded is a constant, clone the constant in | |||
| // # the gradient loop rather than using a stack. | |||
| // # TODO(phawkins): consider hoisting the constant out of the loop | |||
| // # instead. | |||
| // real_value = constant_op.constant( | |||
| // tensor_util.constant_value(cur_value), dtype=cur_value.dtype) | |||
| // break | |||
| // else: | |||
| // # Record the history of this value in forward_ctxt. | |||
| // self._grad_context.Exit() | |||
| // history_value = cur_grad_state.AddForwardAccumulator(cur_value) | |||
| // self._grad_context.Enter() | |||
| // break | |||
| // if real_value is None: | |||
| // # Add the stack pop op in the grad context. | |||
| // real_value = cur_grad_state.AddBackpropAccumulatedValue( | |||
| // history_value, cur_value) | |||
| // if cur_grad_state != self: | |||
| // real_value = self._grad_context.AddValue(real_value) | |||
| // self._history_map[value.name] = real_value | |||
| // return real_value | |||
| } | |||
| } | |||
| @@ -4,13 +4,15 @@ using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public interface IControlFlowContext | |||
| { | |||
| void AddOp(Operation op); | |||
| IControlFlowContext outer_context { get; } | |||
| HashSet<string> values { get; } | |||
| Tensor AddValue(Tensor val); | |||
| void AddInnerOp(Operation resultOp); | |||
| object to_proto(); | |||
| } | |||
| // henon: this was too much trouble. There is no value, just cost, in using an interface here. | |||
| //public interface IControlFlowContext | |||
| //{ | |||
| // void AddOp(Operation op); | |||
| // IControlFlowContext outer_context { get; } | |||
| // HashSet<string> values { get; } | |||
| // Tensor pivot { get; } | |||
| // Tensor AddValue(Tensor val); | |||
| // void AddInnerOp(Operation resultOp); | |||
| // object to_proto(); | |||
| //} | |||
| } | |||
| @@ -1,11 +1,26 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations.ControlFlows; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| public class WhileContext : ControlFlowContext | |||
| { | |||
| private bool _back_prop = true; | |||
| private GradLoopState _grad_state = null; | |||
| public override WhileContext GetWhileContext() | |||
| { | |||
| return this; | |||
| } | |||
| public override GradLoopState grad_state => _grad_state; | |||
| public override bool back_prop => _back_prop; | |||
| public static WhileContext from_proto(object proto) | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -7,7 +7,7 @@ namespace Tensorflow | |||
| { | |||
| public partial class Operation | |||
| { | |||
| private IControlFlowContext _control_flow_context; | |||
| private ControlFlowContext _control_flow_context; | |||
| /// <summary> | |||
| /// Add this op to its control flow context. | |||
| @@ -39,12 +39,12 @@ namespace Tensorflow | |||
| _add_control_input(op); | |||
| } | |||
| public void _set_control_flow_context(IControlFlowContext ctx) | |||
| public void _set_control_flow_context(ControlFlowContext ctx) | |||
| { | |||
| _control_flow_context = ctx; | |||
| } | |||
| public IControlFlowContext _get_control_flow_context() | |||
| public ControlFlowContext _get_control_flow_context() | |||
| { | |||
| return _control_flow_context; | |||
| } | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Operations.ControlFlows; | |||
| using util = Tensorflow.control_flow_util; | |||
| namespace Tensorflow | |||
| @@ -93,9 +94,9 @@ namespace Tensorflow | |||
| /// <param name="between_op_list"></param> | |||
| /// <param name="between_ops"></param> | |||
| /// <param name="colocate_gradients_with_ops"></param> | |||
| public static object MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops) | |||
| public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops) | |||
| { | |||
| object loop_state = null; | |||
| ControlFlowState loop_state = null; | |||
| foreach (var op in between_op_list) | |||
| { | |||
| @@ -103,7 +104,7 @@ namespace Tensorflow | |||
| { | |||
| if(loop_state == null) | |||
| { | |||
| // loop_state = ControlFlowState(); | |||
| loop_state = new ControlFlowState(); | |||
| } | |||
| } | |||
| } | |||
| @@ -207,7 +208,7 @@ namespace Tensorflow | |||
| /// `(output_false, output_true)`: If `pred` is true, data will be forwarded to | |||
| /// `output_true`, otherwise it goes to `output_false`. | |||
| /// </returns> | |||
| public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | |||
| public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | |||
| { | |||
| data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); | |||
| // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below | |||
| @@ -298,7 +299,9 @@ namespace Tensorflow | |||
| */ | |||
| // Add the Switch to the graph. | |||
| var (p_2, p_1) = @switch(pred, pred); | |||
| var switch_result= @switch(pred, pred); | |||
| var p_2=switch_result[0]; | |||
| var p_1 = switch_result[1]; | |||
| var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | |||
| var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | |||
| pred = array_ops.identity(pred, name: "pred_id"); | |||
| @@ -379,7 +382,9 @@ namespace Tensorflow | |||
| return with(ops.name_scope(name, "cond", new { pred }), delegate | |||
| { | |||
| // Add the Switch to the graph. | |||
| var (p_2, p_1) = @switch(pred, pred); | |||
| var switch_result = @switch(pred, pred); | |||
| var p_2 = switch_result[0]; | |||
| var p_1 = switch_result[1]; | |||
| var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | |||
| var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | |||
| pred = array_ops.identity(pred, name: "pred_id"); | |||
| @@ -460,7 +465,7 @@ namespace Tensorflow | |||
| /// <param name="pred"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="name"></param> | |||
| public static (Tensor, Tensor) @switch(Tensor data, | |||
| public static Tensor[] @switch(Tensor data, | |||
| Tensor pred, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| string name = null) | |||
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Return the control flow context for the output of an op. | |||
| /// </summary> | |||
| public static IControlFlowContext GetOutputContext(Operation op) | |||
| public static ControlFlowContext GetOutputContext(Operation op) | |||
| { | |||
| var ctxt = op._get_control_flow_context(); | |||
| // Exit nodes usually have a control flow context, except in the case where the | |||
| @@ -33,14 +33,14 @@ namespace Tensorflow | |||
| /// output_false: A `Tensor`. Has the same type as `data`. | |||
| /// output_true: A `Tensor`. Has the same type as `data`. | |||
| /// </returns> | |||
| public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null) | |||
| public static Tensor[] @switch(Tensor data, Tensor pred, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); | |||
| var _inputs_flat = _op.inputs; | |||
| var _attrs = ("T", _op.get_attr("T")); | |||
| // TODO: missing original code | |||
| //_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name); | |||
| return (_op.outputs[0], _op.outputs[1]); | |||
| return new []{_op.outputs[0], _op.outputs[1]}; | |||
| } | |||
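For callers, the switch from a tuple to a Tensor[] return changes how the two outputs are picked apart. The hypothetical helper below shows the intended indexing (0 = false output, 1 = true output, as in the doc comment above); the containing class and namespace (gen_control_flow_ops in Tensorflow) are assumptions, since this hunk does not show them.

```csharp
using Tensorflow;

public static class SwitchUsageSketch
{
    // Hypothetical helper, not part of this commit: forwards `data` to one of
    // two outputs depending on `pred`, using the array-returning @switch.
    public static (Tensor output_false, Tensor output_true) Split(Tensor data, Tensor pred)
    {
        var outputs = gen_control_flow_ops.@switch(data, pred);
        return (outputs[0], outputs[1]);  // index 0: pred false, index 1: pred true
    }
}
```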
| public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) | |||