| @@ -1,12 +1,79 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Gradients for operators defined in control_flow_ops.py.cs | |||||
| /// </summary> | |||||
| public class control_flow_grad | public class control_flow_grad | ||||
| { | { | ||||
| /// <summary> | |||||
/// Gradients for a Switch op are calculated using a Merge op.
| /// | |||||
| /// If the switch is a loop switch, it will be visited twice. We create | |||||
| /// the merge on the first visit, and update the other input of the merge | |||||
/// on the second visit. A next_iteration is also added on the second visit.
| /// </summary> | |||||
/// <returns>The input gradients: one for the data input and null for the predicate.</returns>
public Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
| { | |||||
| throw new NotImplementedException("_SwitchGrad"); | |||||
| //graph = ops.get_default_graph() | |||||
| //# pylint: disable=protected-access | |||||
| //op_ctxt = op._get_control_flow_context() | |||||
| //grad_ctxt = graph._get_control_flow_context() | |||||
| //# pylint: enable=protected-access | |||||
| //if isinstance(op_ctxt, WhileContext): | |||||
| // merge_grad = grad_ctxt.grad_state.switch_map.get(op) | |||||
| // if merge_grad is not None: | |||||
| // # This is the second time this Switch is visited. It comes from | |||||
| // # the non-exit branch of the Switch, so update the second input | |||||
| // # to the Merge. | |||||
| // # TODO(yuanbyu): Perform shape inference with this new input. | |||||
| // if grad[1] is not None: | |||||
| // # pylint: disable=protected-access | |||||
| // control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1], | |||||
| // enforce_shape_invariant=False) | |||||
| // # pylint: enable=protected-access | |||||
| // return None, None | |||||
| // elif grad[0] is not None: | |||||
| // # This is the first time this Switch is visited. It comes from | |||||
| // # the Exit branch, which is grad[0]. grad[1] is empty at this point. | |||||
| // # Use grad[0] for both inputs to merge for now, but update the second | |||||
| // # input of merge when we see this Switch the second time. | |||||
| // merge_grad = merge([grad[0], grad[0]], name="b_switch")[0] | |||||
| // grad_ctxt.grad_state.switch_map[op] = merge_grad | |||||
| // return merge_grad, None | |||||
| // else: | |||||
| // # This is the first time this Switch is visited. It comes from the | |||||
| // # Identity branch. Such a Switch has `None` gradient for the Exit branch, | |||||
| // # meaning the output is not differentiable. | |||||
| // return None, None | |||||
| //elif isinstance(op_ctxt, CondContext): | |||||
| // zero_grad = grad[1 - op_ctxt.branch] | |||||
| // # At this point, we have created zero_grad guarded by the right switch. | |||||
| // # Unfortunately, we may still get None here for not trainable data types. | |||||
| // if zero_grad is None: | |||||
| // # For resource variables we get None always on the other branch, so bypass | |||||
| // # this. | |||||
| // if op.inputs[0].dtype == dtypes.resource: | |||||
| // return merge( | |||||
| // [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None | |||||
| // return None, None | |||||
| // return merge(grad, name="cond_grad")[0], None | |||||
| //else: | |||||
| // false_grad = switch(grad[0], op.inputs[1])[0] | |||||
| // true_grad = switch(grad[1], op.inputs[1])[1] | |||||
| // return merge([false_grad, true_grad])[0], None | |||||
| } | |||||
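// A possible C# rendering of the final Python branch above (a Switch outside any
// while/cond context); this is only a sketch, assuming @switch and a merge helper along
// the lines of control_flow_ops / gen_control_flow_ops:
//   var false_grad = control_flow_ops.@switch(grads[0], op.inputs[1])[0];
//   var true_grad = control_flow_ops.@switch(grads[1], op.inputs[1])[1];
//   return new Tensor[] { gen_control_flow_ops.merge(new[] { false_grad, true_grad }).Item1, null };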
| /// <summary> | |||||
| /// Gradients for a Merge op are calculated using a Switch op. | |||||
| /// </summary> | |||||
| public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) | public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) | ||||
| { | { | ||||
| var grad = grads[0]; | var grad = grads[0]; | ||||
| @@ -14,10 +81,164 @@ namespace Tensorflow.Gradients | |||||
| var input_op = op.inputs[0].op; | var input_op = op.inputs[0].op; | ||||
| var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
| var op_ctxt = control_flow_util.GetOutputContext(input_op); | var op_ctxt = control_flow_util.GetOutputContext(input_op); | ||||
| var pred = (op_ctxt as CondContext).pred; | |||||
| var grad_ctxt = graph._get_control_flow_context(); | |||||
| switch (op_ctxt) | |||||
| { | |||||
| case WhileContext cwhile: | |||||
| { | |||||
| return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot); | |||||
| } | |||||
| case CondContext ccond: | |||||
| { | |||||
| var pred = ccond.pred; | |||||
| if (grad_ctxt != null && grad_ctxt.grad_state != null) | |||||
| { | |||||
| //# This Merge node is part of a cond within a loop. | |||||
| //# The backprop needs to have the value of this predicate for every | |||||
| //# iteration. So we must have its values accumulated in the forward, and | |||||
| //# use the accumulated values as the predicate for this backprop switch. | |||||
| var grad_state = grad_ctxt.grad_state; | |||||
| var real_pred = grad_state.history_map[pred.name] as Tensor; | |||||
| if (real_pred == null) | |||||
| { | |||||
| //# Remember the value of pred for every iteration. | |||||
| grad_ctxt = grad_state.grad_context; | |||||
| grad_ctxt.Exit(); | |||||
| var history_pred = grad_state.AddForwardAccumulator(pred); | |||||
| grad_ctxt.Enter(); | |||||
| //# Add the stack pop op. If pred.op is in a (outer) CondContext, | |||||
| //# the stack pop will be guarded with a switch. | |||||
| real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred); | |||||
| grad_state.history_map[pred.name] = real_pred; | |||||
| } | |||||
| pred = real_pred; | |||||
| } | |||||
| var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); | |||||
| return results; | |||||
| } | |||||
| default: | |||||
| { | |||||
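// Not in a cond or while context: this is an N-way Merge. op.outputs[1] is the
// value_index output of the Merge, so route the incoming gradient to the input that
// was actually taken; the other inputs get the dead (untaken) Switch output.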
| var num_inputs = op.inputs.Length; | |||||
| var cond = new Tensor[num_inputs]; | |||||
| for (int i = 0; i < num_inputs; i++) | |||||
| cond[i] = math_ops.equal(op.outputs[1], i); | |||||
| var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray(); | |||||
| return result; | |||||
| } | |||||
| } | |||||
| var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); | |||||
| return new Tensor[] { results.Item1, results.Item2 }; | |||||
| } | } | ||||
| } | |||||
| public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| return _MergeGrad(op, grads); | |||||
| } | |||||
| /// <summary> | |||||
| /// Gradients for an exit op are calculated using an Enter op. | |||||
| /// </summary> | |||||
| public Tensor[] _ExitGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| throw new NotImplementedException("_ExitGrad"); | |||||
| // graph = ops.get_default_graph() | |||||
| //# pylint: disable=protected-access | |||||
| // op_ctxt = op._get_control_flow_context() | |||||
| // grad_ctxt = graph._get_control_flow_context() | |||||
| // # pylint: enable=protected-access | |||||
| // if not grad_ctxt.back_prop: | |||||
| // # The flag `back_prop` is set by users to suppress gradient | |||||
| // # computation for this loop. If the attribute `back_prop` is false, | |||||
| // # no gradient computation. | |||||
| // return None | |||||
| // if op_ctxt.grad_state: | |||||
| // raise TypeError("Second-order gradient for while loops not supported.") | |||||
| // if isinstance(grad, ops.Tensor) : | |||||
| // grad_ctxt.AddName(grad.name) | |||||
| // else: | |||||
| // if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)): | |||||
| // raise TypeError("Type %s not supported" % type(grad)) | |||||
| // grad_ctxt.AddName(grad.values.name) | |||||
| // grad_ctxt.AddName(grad.indices.name) | |||||
| // dense_shape = grad.dense_shape | |||||
| // if dense_shape is not None: | |||||
| // grad_ctxt.AddName(dense_shape.name) | |||||
| // grad_ctxt.Enter() | |||||
| // # pylint: disable=protected-access | |||||
| // result = control_flow_ops._Enter( | |||||
| // grad, grad_ctxt.name, is_constant=False, | |||||
| // parallel_iterations=grad_ctxt.parallel_iterations, | |||||
| // name="b_exit") | |||||
| // # pylint: enable=protected-access | |||||
| // grad_ctxt.loop_enters.append(result) | |||||
| // grad_ctxt.Exit() | |||||
| // return result | |||||
| } | |||||
| /// <summary> | |||||
| /// A forward next_iteration is translated into a backprop identity. | |||||
| /// | |||||
| /// Note that the backprop next_iteration is added in switch grad. | |||||
| /// </summary> | |||||
| public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad) | |||||
| { | |||||
| return (_, grad); | |||||
| } | |||||
| public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad) | |||||
| { | |||||
| return (_, grad); | |||||
| } | |||||
| /// <summary> | |||||
| /// Gradients for an Enter are calculated using an Exit op. | |||||
| /// | |||||
/// For loop variables, grad is the gradient, so we just add an Exit op.
/// For loop invariants, we need to add an accumulator loop.
/// </summary>
public (object, Tensor[]) _EnterGrad(Operation op, Tensor[] grad)
| { | |||||
| throw new NotImplementedException("_EnterGrad"); | |||||
| // graph = ops.get_default_graph() | |||||
| //# pylint: disable=protected-access | |||||
| // grad_ctxt = graph._get_control_flow_context() | |||||
| // # pylint: enable=protected-access | |||||
| // if not grad_ctxt.back_prop: | |||||
| // # Skip gradient computation, if the attribute `back_prop` is false. | |||||
| // return grad | |||||
| // if grad_ctxt.grad_state is None: | |||||
| // # Pass the gradient through if we are not in a gradient while context. | |||||
| // return grad | |||||
| // if op.get_attr("is_constant"): | |||||
| // # Add a gradient accumulator for each loop invariant. | |||||
| // if isinstance(grad, ops.Tensor) : | |||||
| // result = grad_ctxt.AddBackpropAccumulator(op, grad) | |||||
| // elif isinstance(grad, ops.IndexedSlices) : | |||||
| // result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad) | |||||
| // else: | |||||
| // # TODO(yuanbyu, lukasr): Add support for SparseTensor. | |||||
| // raise TypeError("Type %s not supported" % type(grad)) | |||||
| // else: | |||||
| // result = exit(grad) | |||||
| // grad_ctxt.loop_exits.append(result) | |||||
| // grad_ctxt.ExitResult([result]) | |||||
| // return result | |||||
| } | |||||
public (object, Tensor[]) _RefEnterGrad(Operation op, Tensor[] grad)
| { | |||||
| return _EnterGrad(op, grad); | |||||
| } | |||||
| /// <summary> | |||||
| /// Stop backprop for the predicate of a while loop. | |||||
| /// </summary> | |||||
| public object _LoopCondGrad(object _) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| @@ -3,13 +3,14 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Operations; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class Graph | public partial class Graph | ||||
| { | { | ||||
| // Current control flow context. It could be either CondContext or WhileContext | // Current control flow context. It could be either CondContext or WhileContext | ||||
| public IControlFlowContext _control_flow_context; | |||||
| public ControlFlowContext _control_flow_context; | |||||
| // represents the nested with(...) statements | // represents the nested with(...) statements | ||||
| public List<_ControlDependenciesController> _control_dependencies_stack { get; set; } = new List<_ControlDependenciesController>(); | public List<_ControlDependenciesController> _control_dependencies_stack { get; set; } = new List<_ControlDependenciesController>(); | ||||
| @@ -97,7 +98,7 @@ namespace Tensorflow | |||||
| /// Returns the current control flow context. | /// Returns the current control flow context. | ||||
| /// </summary> | /// </summary> | ||||
| /// <returns>A context object.</returns> | /// <returns>A context object.</returns> | ||||
| public IControlFlowContext _get_control_flow_context() | |||||
| public ControlFlowContext _get_control_flow_context() | |||||
| { | { | ||||
| return _control_flow_context; | return _control_flow_context; | ||||
| } | } | ||||
| @@ -106,7 +107,7 @@ namespace Tensorflow | |||||
| /// Sets the current control flow context. | /// Sets the current control flow context. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="ctx">a context object.</param> | /// <param name="ctx">a context object.</param> | ||||
| public void _set_control_flow_context(IControlFlowContext ctx) | |||||
| public void _set_control_flow_context(ControlFlowContext ctx) | |||||
| { | { | ||||
| _control_flow_context = ctx; | _control_flow_context = ctx; | ||||
| } | } | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Operations; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -15,7 +16,7 @@ namespace Tensorflow | |||||
| private List<ITensorOrOperation> _seen_nodes; | private List<ITensorOrOperation> _seen_nodes; | ||||
| private List<_ControlDependenciesController> _old_stack; | private List<_ControlDependenciesController> _old_stack; | ||||
| private bool _new_stack; | private bool _new_stack; | ||||
| private IControlFlowContext _old_control_flow_context; | |||||
| private ControlFlowContext _old_control_flow_context; | |||||
| public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations.ControlFlows; | |||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| { | { | ||||
| @@ -107,8 +108,8 @@ namespace Tensorflow.Operations | |||||
| with(ops.control_dependencies(null), ctrl => | with(ops.control_dependencies(null), ctrl => | ||||
| { | { | ||||
| var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred); | |||||
| result = new[] { r0, r1 }[_branch]; | |||||
| var results = control_flow_ops._SwitchRefOrTensor(result, _pred); | |||||
| result = results[_branch]; | |||||
| if (_outer_context != null) | if (_outer_context != null) | ||||
| _outer_context.AddInnerOp(result.op); | _outer_context.AddInnerOp(result.op); | ||||
| }); | }); | ||||
| @@ -118,7 +119,7 @@ namespace Tensorflow.Operations | |||||
| // Mark Switch output as seen by this context and any outer contexts, | // Mark Switch output as seen by this context and any outer contexts, | ||||
| // just like what we do for normal op outputs in _AddOpInternal() below. | // just like what we do for normal op outputs in _AddOpInternal() below. | ||||
| IControlFlowContext ctxt = this; | |||||
| ControlFlowContext ctxt = this; | |||||
| while (ctxt != null) | while (ctxt != null) | ||||
| { | { | ||||
| ctxt.values.Add(result.name); | ctxt.values.Add(result.name); | ||||
| @@ -223,8 +224,8 @@ namespace Tensorflow.Operations | |||||
| _values.Add(real_val.name); | _values.Add(real_val.name); | ||||
| _external_values[real_val.name] = real_val; | _external_values[real_val.name] = real_val; | ||||
| } | } | ||||
| var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred); | |||||
| real_val = new[] {t0, t1}[_branch]; | |||||
| var results = control_flow_ops._SwitchRefOrTensor(real_val, _pred); | |||||
| real_val = results[_branch]; | |||||
| _external_values[val.name] = real_val; | _external_values[val.name] = real_val; | ||||
| } | } | ||||
| else | else | ||||
| @@ -238,8 +239,8 @@ namespace Tensorflow.Operations | |||||
| return real_val; | return real_val; | ||||
| } | } | ||||
| protected override void _AddOpInternal(Operation op) | |||||
| { | |||||
| protected override void _AddOpInternal(Operation op) | |||||
| { | |||||
| if (op.inputs.Length == 0) | if (op.inputs.Length == 0) | ||||
| { | { | ||||
| //If we're in a while loop, remove any control inputs from outside the | //If we're in a while loop, remove any control inputs from outside the | ||||
| @@ -282,11 +283,11 @@ namespace Tensorflow.Operations | |||||
| // TODO: implement below code dependencies | // TODO: implement below code dependencies | ||||
| //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") | //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") | ||||
| // op._add_control_input(_pivot.op); | // op._add_control_input(_pivot.op); | ||||
| } | |||||
| // Mark op's outputs as seen by this context and any outer contexts. | |||||
| } | |||||
| // Mark op's outputs as seen by this context and any outer contexts. | |||||
| var output_names = op.outputs.Select(x => x.name).ToArray(); | var output_names = op.outputs.Select(x => x.name).ToArray(); | ||||
| IControlFlowContext ctxt = this; | |||||
| ControlFlowContext ctxt = this; | |||||
| while (ctxt != null) | while (ctxt != null) | ||||
| { | { | ||||
| foreach (var name in output_names) | foreach (var name in output_names) | ||||
| @@ -298,9 +299,31 @@ namespace Tensorflow.Operations | |||||
| op.graph.prevent_fetching(op); | op.graph.prevent_fetching(op); | ||||
| if (_outer_context != null) | if (_outer_context != null) | ||||
| _outer_context.AddInnerOp(op); | |||||
| } | |||||
| _outer_context.AddInnerOp(op); | |||||
| } | |||||
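// A cond nested inside a while loop shares the loop's gradient bookkeeping: both
// properties below defer to the enclosing WhileContext found via GetWhileContext();
// a stand-alone cond has no grad state and reports back_prop as false.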
| public override GradLoopState grad_state | |||||
| { | |||||
| get | |||||
| { | |||||
| var whc = GetWhileContext(); | |||||
| if (whc != null) | |||||
| return whc.grad_state; | |||||
| return null; | |||||
| } | |||||
| } | |||||
| public override bool back_prop | |||||
| { | |||||
| get | |||||
| { | |||||
| var whc = GetWhileContext(); | |||||
| if (whc != null) | |||||
| return whc.back_prop; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| public CondContextDef to_proto(string export_scope) | public CondContextDef to_proto(string export_scope) | ||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations.ControlFlows; | |||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| { | { | ||||
| @@ -22,21 +23,25 @@ namespace Tensorflow.Operations | |||||
| /// 4. A ControlFlowContext has _context_stack. | /// 4. A ControlFlowContext has _context_stack. | ||||
| /// Pushed and popped by ctxt.Enter() and ctxt.Exit() | /// Pushed and popped by ctxt.Enter() and ctxt.Exit() | ||||
| /// </summary> | /// </summary> | ||||
| public abstract class ControlFlowContext : Python, IPython, IControlFlowContext | |||||
| public abstract class ControlFlowContext : Python, IPython | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// The predicate tensor in this branch | /// The predicate tensor in this branch | ||||
| /// </summary> | /// </summary> | ||||
| protected Tensor _pivot; | protected Tensor _pivot; | ||||
| public Tensor pivot | |||||
| { | |||||
| get => _pivot; | |||||
| } | |||||
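// Exposed so gradient code (e.g. the WhileContext case in _MergeGrad) can build a Switch
// on the loop pivot without reaching into the protected field.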
| protected Stack<IControlFlowContext> _context_stack; | |||||
| protected IControlFlowContext _outer_context; | |||||
| protected Stack<ControlFlowContext> _context_stack; | |||||
| protected ControlFlowContext _outer_context; | |||||
| protected Dictionary<string, ITensorOrOperation> _external_values; | protected Dictionary<string, ITensorOrOperation> _external_values; | ||||
| public ControlFlowContext() | public ControlFlowContext() | ||||
| { | { | ||||
| _context_stack = new Stack<IControlFlowContext>(); | |||||
| _context_stack = new Stack<ControlFlowContext>(); | |||||
| } | } | ||||
| public string name { get => _name; } | public string name { get => _name; } | ||||
| @@ -111,8 +116,13 @@ namespace Tensorflow.Operations | |||||
| _AddOpInternal(op); | _AddOpInternal(op); | ||||
| } | } | ||||
| public IControlFlowContext outer_context { get { return _outer_context; } } | |||||
| public ControlFlowContext outer_context { get { return _outer_context; } } | |||||
| public HashSet<string> values => _values; | public HashSet<string> values => _values; | ||||
| public virtual GradLoopState grad_state => throw new NotImplementedException("abstract method"); | |||||
| public virtual bool back_prop => throw new NotImplementedException("abstract method"); | |||||
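// Overridden by WhileContext (which owns the actual state) and by CondContext (which
// defers to its enclosing WhileContext) later in this change.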
| public virtual Tensor AddValue(Tensor val) | public virtual Tensor AddValue(Tensor val) | ||||
| { | { | ||||
| // to be overridden | // to be overridden | ||||
| @@ -147,7 +157,7 @@ namespace Tensorflow.Operations | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns true if `maybe_containing_ctxt` is or contains `ctxt`. | /// Returns true if `maybe_containing_ctxt` is or contains `ctxt`. | ||||
| /// </summary> | /// </summary> | ||||
| public static bool IsContainingContext(IControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt) | |||||
| public static bool IsContainingContext(ControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt) | |||||
| { | { | ||||
| while (ctxt != maybe_containing_ctxt) | while (ctxt != maybe_containing_ctxt) | ||||
| { | { | ||||
| @@ -164,6 +174,16 @@ namespace Tensorflow.Operations | |||||
| var internal_control_inputs = op.control_inputs; | var internal_control_inputs = op.control_inputs; | ||||
| } | } | ||||
| /// <summary> | |||||
/// Returns the innermost WhileContext that contains this context, or null if there is none.
| /// </summary> | |||||
| public virtual WhileContext GetWhileContext() | |||||
| { | |||||
| if (_outer_context != null) | |||||
| return _outer_context.GetWhileContext(); | |||||
| return null; | |||||
| } | |||||
| public object to_proto() | public object to_proto() | ||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| @@ -173,5 +193,6 @@ namespace Tensorflow.Operations | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,277 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Operations.ControlFlows | |||||
| { | |||||
| /// <summary> | |||||
| /// Maintain the mapping from the loops to their grad states. | |||||
| /// </summary> | |||||
| public class ControlFlowState | |||||
| { | |||||
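// Currently a construction-only stub: MaybeCreateControlFlowState (updated further down
// in this change) just needs to instantiate it, and the original Python implementation is
// kept below as the reference for the eventual port.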
| //class ControlFlowState(object): | |||||
| // """Maintain the mapping from the loops to their grad states.""" | |||||
| // def __init__(self): | |||||
| // self._map = {} # maps forward loop context to GradLoopState | |||||
| // def GetGradState(self, op, before): | |||||
| // """Return the grad state for this op if it's in a forward loop context.""" | |||||
| // if before and util.IsLoopExit(op): | |||||
| // forward_ctxt = op._get_control_flow_context() | |||||
| // forward_ctxt = forward_ctxt.outer_context | |||||
| // if forward_ctxt: | |||||
| // forward_ctxt = forward_ctxt.GetWhileContext() | |||||
| // else: | |||||
| // forward_ctxt = _GetWhileContext(op) | |||||
| // if forward_ctxt: | |||||
| // return self._map.get(forward_ctxt) | |||||
| // return None | |||||
| // def ProcessUnusedLoopExits(self, pending_count, to_ops_set): | |||||
| // """Process all the "unused" loop exits. | |||||
| // The "unused" exits of the loops are added to `unused_exits`. An exit is | |||||
| // unused if its pending_count is 0. If there is an exit with real gradient, | |||||
| // all these deferred exits will enter the backprop loop with zero gradient. | |||||
| // Otherwise, they will enter the backprop loop with None. As an example, | |||||
| // people often write: | |||||
| // ```python | |||||
| // v1, _ = tf.while_loop(p, b, [x1, x2]) | |||||
| // result = gradients(v1, x1) | |||||
| // ``` | |||||
| // The exit node for x2 is not included by the betweenness analysis. But we | |||||
| // need to backprop x2 if x2 is involved in computing v1. | |||||
| // Args: | |||||
| // pending_count: The number of backprop inputs for every op. | |||||
| // to_ops_set: The set of ops for ys in gradients(ys, xs) | |||||
| // Returns: | |||||
| // The set of unused loop exits that we know at this point we need | |||||
| // to backprop. | |||||
| // """ | |||||
| // loop_exits = [] | |||||
| // for grad_state in self._map.values(): | |||||
| // for y in grad_state.forward_loop_exits: | |||||
| // if pending_count[y.op] == 0: | |||||
| // grad_state.pending_exits_count -= 1 | |||||
| // if y.op not in to_ops_set: | |||||
| // grad_state.unused_exits.append(y) | |||||
| // if grad_state.pending_exits_count == 0: | |||||
| // loop_exits.extend(grad_state.unused_exits) | |||||
| // # Need to include Enters in backprop for higher-order gradients. | |||||
| // for y in grad_state.forward_context.loop_enters: | |||||
| // if pending_count[y.op] == 0: | |||||
| // pending_count[y.op] = 1 | |||||
| // return loop_exits | |||||
| // def EnterGradWhileContext(self, op, before): | |||||
| // """Enter the WhileContext for gradient computation.""" | |||||
| // grad_state = self.GetGradState(op, before) | |||||
| // if grad_state: | |||||
| // grad_state.grad_context.Enter() | |||||
| // def ExitGradWhileContext(self, op, before): | |||||
| // """Exit the WhileContext for gradient computation.""" | |||||
| // grad_state = self.GetGradState(op, before) | |||||
| // if grad_state: | |||||
| // grad_state.grad_context.Exit() | |||||
| // def AddWhileContext(self, op, between_op_list, between_ops): | |||||
| // """Add the grad state for the while loop that op belongs to. | |||||
| // Note that op is an Exit, and this method must be called in | |||||
| // the control flow context where gradients() is called. | |||||
| // Note that this method modifies `between_op_list` and `between_ops`. | |||||
| // """ | |||||
| // forward_ctxt = _GetWhileContext(op) | |||||
| // grad_state = self._map.get(forward_ctxt) | |||||
| // if grad_state is None: | |||||
| // # This is a new while loop so create a grad state for it. | |||||
| // outer_forward_ctxt = forward_ctxt.outer_context | |||||
| // if outer_forward_ctxt: | |||||
| // outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() | |||||
| // outer_grad_state = None | |||||
| // if outer_forward_ctxt: | |||||
| // outer_grad_state = self._map.get(outer_forward_ctxt) | |||||
| // grad_state = GradLoopState(forward_ctxt, outer_grad_state) | |||||
| // self._map[forward_ctxt] = grad_state | |||||
| // # We need to include all exits of a loop for backprop. | |||||
| // for loop_exit in grad_state.forward_loop_exits: | |||||
| // if loop_exit.op not in between_ops: | |||||
| // between_ops.add(loop_exit.op) | |||||
| // between_op_list.append(loop_exit.op) | |||||
| // def ZerosLikeForExit(self, val): | |||||
| // """Create zeros_like gradient for a loop exit. | |||||
| // If the result of a loop variable is not used but is involved in | |||||
| // computing the result of some needed loop variable, we create a | |||||
| // zero-valued tensor that is fed as gradient for the Exit node of that | |||||
| // loop variable. Note that val.op is an Exit, and this method must be | |||||
| // called in the control flow context where gradients() is called. | |||||
| // Args: | |||||
| // val: The output tensor of an Exit op. | |||||
| // Returns: | |||||
| // A zero tensor of the same shape of val. | |||||
| // """ | |||||
| // val_shape = val.get_shape() | |||||
| // forward_ctxt = val.op._get_control_flow_context() | |||||
| // outer_forward_ctxt = forward_ctxt.outer_context | |||||
| // if outer_forward_ctxt: | |||||
| // outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() | |||||
| // outer_grad_state = None | |||||
| // if outer_forward_ctxt: | |||||
| // outer_grad_state = self._map.get(outer_forward_ctxt) | |||||
| // if outer_grad_state: | |||||
| // # This is a nested loop. | |||||
| // if val_shape.is_fully_defined(): | |||||
| // # If the shape is known statically, just create a zero tensor | |||||
| // # with the right shape in the right context. | |||||
| // outer_grad_state.grad_context.Enter() | |||||
| // result = array_ops.zeros(val_shape.dims, val.dtype) | |||||
| // outer_grad_state.grad_context.Exit() | |||||
| // else: | |||||
| // # Only the shape of value is needed for backprop. | |||||
| // forward_ctxt.outer_context.Enter() | |||||
| // shape = array_ops.shape_internal(val, optimize=False) | |||||
| // forward_ctxt.outer_context.Exit() | |||||
| // # Save the shape to a stack. | |||||
| // history_shape = outer_grad_state.AddForwardAccumulator(shape) | |||||
| // # Get the shape back from the stack. | |||||
| // outer_grad_ctxt = outer_grad_state.grad_context | |||||
| // outer_grad_ctxt.Enter() | |||||
| // real_shape = outer_grad_state.AddBackpropAccumulatedValue( | |||||
| // history_shape, shape) | |||||
| // result = array_ops.zeros(real_shape, val.dtype) | |||||
| // outer_grad_ctxt.Exit() | |||||
| // else: | |||||
| // # This is not a nested loop. | |||||
| // if val_shape.is_fully_defined(): | |||||
| // # If the shape is known statically, just create a zero tensor | |||||
| // # with the right shape. | |||||
| // result = array_ops.zeros(val_shape.dims, val.dtype) | |||||
| // else: | |||||
| // result = array_ops.zeros_like(val, optimize=False) | |||||
| // return result | |||||
| // def ZerosLike(self, op, index): | |||||
| // """Create zeros_like for the specified output of an op. | |||||
| // If op is in a while loop that is part of gradients(), this method | |||||
| // must be called in its grad loop context. | |||||
| // Args: | |||||
| // op: A tensorflow operation. | |||||
| // index: the index for a specific output of the op. | |||||
| // Returns: | |||||
| // A zero tensor of the same shape of op.outputs[index]. | |||||
| // """ | |||||
| // if util.IsLoopSwitch(op): | |||||
| // return None | |||||
| // if op.graph._building_function: # pylint: disable=protected-access | |||||
| // # The optimization here is tricky to apply to functions | |||||
| // return array_ops.zeros_like(op.outputs[index]) | |||||
| // dead_branch = util.IsSwitch(op) | |||||
| // forward_ctxt = _GetWhileContext(op) | |||||
| // grad_state = self._map.get(forward_ctxt) | |||||
| // if grad_state is None: | |||||
| // # op is not in a while loop that is part of gradients(). | |||||
| // return ZerosLikeOutsideLoop(op, index) | |||||
| // op_ctxt = op._get_control_flow_context() | |||||
| // val = ops.convert_to_tensor(op.outputs[index], name="tensor") | |||||
| // shape = val.get_shape() | |||||
| // if shape.is_fully_defined(): | |||||
| // # If the shape is known statically, just create a zero tensor with | |||||
| // # the right shape in the grad loop context. | |||||
| // result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype) | |||||
| // if dead_branch: | |||||
| // # op is a cond switch. Guard the zero tensor with a switch. | |||||
| // pred = grad_state.history_map.get(op_ctxt.pred.name) | |||||
| // branch = op_ctxt.branch | |||||
| // result = _SwitchRefOrTensor(result, pred)[1 - branch] | |||||
| // else: | |||||
| // # Unknown shape so keep a history of the shape at runtime. | |||||
| // if dead_branch: | |||||
| // # Need to add a special switch to guard the value. | |||||
| // pred = op_ctxt.pred | |||||
| // branch = op_ctxt.branch | |||||
| // op_ctxt.outer_context.Enter() | |||||
| // val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch] | |||||
| // zeros_shape = array_ops.shape_internal(val, optimize=False) | |||||
| // op_ctxt.outer_context.Exit() | |||||
| // val.op._set_control_flow_context(op_ctxt) | |||||
| // zeros_shape.op._set_control_flow_context(op_ctxt) | |||||
| // else: | |||||
| // op_ctxt.Enter() | |||||
| // zeros_shape = array_ops.shape_internal(val, optimize=False) | |||||
| // op_ctxt.Exit() | |||||
| // # Add forward accumulator for shape. | |||||
| // grad_state.grad_context.Exit() | |||||
| // history_zeros_shape = grad_state.AddForwardAccumulator( | |||||
| // zeros_shape, dead_branch=dead_branch) | |||||
| // grad_state.grad_context.Enter() | |||||
| // # Create a zero tensor with the right shape. | |||||
| // shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape, | |||||
| // zeros_shape, dead_branch) | |||||
| // result = array_ops.zeros(shape, val.dtype) | |||||
| // return result | |||||
| // def PostProcessing(self): | |||||
| // """Perform postprocessing at the end of gradients(). | |||||
| // We have created the gradient graph at this point. So this function | |||||
| // can be used to perform any postprocessing on the gradient graph. | |||||
| // We currently perform the following postprocessing: | |||||
| // 1. Patch the gradient graph if the output of a loop variable | |||||
| // doesn't depend on its input. | |||||
| // """ | |||||
| // for _, grad_state in self._map.items(): | |||||
| // for _, b_merge in grad_state.switch_map.items(): | |||||
| // if b_merge.op.inputs[0] == b_merge.op.inputs[1]: | |||||
| // # The value of this loop variable at iteration i+1 doesn't | |||||
| // # depend on its value at iteration i. So use zeros as the | |||||
| // # gradients for all iterations > 0. | |||||
| // dtype = b_merge.op.inputs[0].dtype | |||||
| // shape = b_merge.op.inputs[0].get_shape() | |||||
| // # pylint: disable=protected-access | |||||
| // if shape.is_fully_defined(): | |||||
| // grad_state.grad_context.Enter() | |||||
| // # Create a zeros and use it for iterations > 0. | |||||
| // grad_val = constant_op.constant(0, dtype=dtype, shape=shape) | |||||
| // next_grad_val = _NextIteration(grad_val) | |||||
| // grad_state.grad_context.Exit() | |||||
| // else: | |||||
| // # Create a zeros in the outer grad context. | |||||
| // outer_grad_ctxt = grad_state.grad_context.outer_context | |||||
| // if outer_grad_ctxt: | |||||
| // outer_grad_ctxt.Enter() | |||||
| // enter_grad_op = b_merge.op.inputs[0].op | |||||
| // enter_grad = enter_grad_op.inputs[0] | |||||
| // grad_shape = array_ops.shape_internal(enter_grad, optimize=False) | |||||
| // grad_val = array_ops.zeros(grad_shape) | |||||
| // if outer_grad_ctxt: | |||||
| // outer_grad_ctxt.Exit() | |||||
| // # Use the zeros for iterations > 0. | |||||
| // grad_state.grad_context.Enter() | |||||
| // next_grad_val = _NextIteration(grad_val) | |||||
| // grad_state.grad_context.Exit() | |||||
| // b_merge.op._update_input(1, next_grad_val) | |||||
| // # pylint: enable=protected-access | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,398 @@ | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Operations.ControlFlows | |||||
| { | |||||
| public class GradLoopState | |||||
| { | |||||
| //class GradLoopState(object): | |||||
| // """The state used for constructing the gradient graph for a while loop. | |||||
| // We create a GradLoopState for each while loop in forward and its | |||||
| // corresponding while loop in backprop. This gives us access to both | |||||
| // the forward and the backprop WhileContexts. | |||||
| // During the construction of gradient graph, any time when we detect | |||||
| // a forward value that is needed for backprop, we create a history | |||||
| // accumulator and add it to `history_map`. Any time when we backprop | |||||
| // a loop switch op (in _SwitchGrad), we add the grad merge op in | |||||
| // `switch_map`. | |||||
| // """ | |||||
| // def __init__(self, forward_ctxt, outer_grad_state): | |||||
| // # The grad loop state for the outer while loop. | |||||
| // self._outer_grad_state = None | |||||
| // # The while loop context for forward. | |||||
| // self._forward_context = None | |||||
| // # The loop counter added by AddForwardLoopCounter. It is the value | |||||
| // # of the loop counter for the next iteration. | |||||
| // self._forward_index = None | |||||
| // # A sync op for forward. | |||||
| // self._forward_sync = None | |||||
| // # The while loop context for backprop. | |||||
| private WhileContext _grad_context = null; | |||||
| public WhileContext grad_context => _grad_context; | |||||
| // # The loop counter added by AddBackpropLoopCounter. It is the value | |||||
| // # of the loop counter for the current iteration. | |||||
| // self._grad_index = None | |||||
| // # A sync op for backprop. | |||||
| // self._grad_sync = None | |||||
| // # Information needed by backprop. | |||||
| private Hashtable _history_map = new Hashtable(); | |||||
| public Hashtable history_map => _history_map; | |||||
| private Hashtable _switch_map = new Hashtable(); | |||||
| public Hashtable switch_map => _switch_map; | |||||
| // self._unused_exits = [] | |||||
| // self._deferred_exits = [] | |||||
| // self._forward_loop_exits = list(forward_ctxt.loop_exits) | |||||
| // self._pending_exits_count = len(forward_ctxt.loop_exits) | |||||
| // self._outer_grad_state = outer_grad_state | |||||
| // if outer_grad_state: | |||||
| // outer_forward_ctxt = outer_grad_state.forward_context | |||||
| // else: | |||||
| // if not hasattr(forward_ctxt, "outer_context"): | |||||
| // raise ValueError("Failed to call gradients on a while loop without" | |||||
| // "properly serializing graph via MetaGraphDef") | |||||
| // outer_forward_ctxt = forward_ctxt.outer_context | |||||
| // # Add the forward loop counter. | |||||
| // with forward_ctxt._graph.as_default(): # pylint: disable=protected-access | |||||
| // if outer_forward_ctxt: | |||||
| // outer_forward_ctxt.Enter() | |||||
| // cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) | |||||
| // if outer_forward_ctxt: | |||||
| // outer_forward_ctxt.Exit() | |||||
| // self._forward_context = forward_ctxt | |||||
| // self._forward_index = forward_index | |||||
| // # Add the backprop WhileContext, and the backprop loop counter. | |||||
| // if outer_grad_state: | |||||
| // # This is a nested loop. Remember the iteration counts for each | |||||
| // # execution of this inner loop. | |||||
| // outer_forward_ctxt.AddName(cnt.name) | |||||
| // history_cnt = outer_grad_state.AddForwardAccumulator(cnt) | |||||
| // outer_grad_ctxt = outer_grad_state.grad_context | |||||
| // outer_grad_ctxt.Enter() | |||||
| // self._grad_context = WhileContext( | |||||
| // maximum_iterations=forward_ctxt.maximum_iterations, | |||||
| // parallel_iterations=forward_ctxt.parallel_iterations, | |||||
| // back_prop=forward_ctxt.back_prop, | |||||
| // swap_memory=forward_ctxt.swap_memory, | |||||
| // name=forward_ctxt.name, | |||||
| // grad_state=self) | |||||
| // real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt) | |||||
| // self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||||
| // real_cnt, outer_grad_state) | |||||
| // outer_grad_ctxt.Exit() | |||||
| // else: | |||||
| // if outer_forward_ctxt: | |||||
| // outer_forward_ctxt.Enter() | |||||
| // self._grad_context = WhileContext( | |||||
| // maximum_iterations=forward_ctxt.maximum_iterations, | |||||
| // parallel_iterations=forward_ctxt.parallel_iterations, | |||||
| // back_prop=forward_ctxt.back_prop, | |||||
| // swap_memory=forward_ctxt.swap_memory, | |||||
| // name=forward_ctxt.name, | |||||
| // grad_state=self) | |||||
| // self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||||
| // cnt, outer_grad_state) | |||||
| // if outer_forward_ctxt: | |||||
| // outer_forward_ctxt.Exit() | |||||
| // @property | |||||
| // def outer_grad_state(self): | |||||
| // """The grad loop state for outer loop.""" | |||||
| // return self._outer_grad_state | |||||
| // @property | |||||
| // def forward_context(self): | |||||
| // """The while loop context for forward.""" | |||||
| // return self._forward_context | |||||
| // @property | |||||
| // def forward_index(self): | |||||
| // """The loop index of forward loop.""" | |||||
| // return self._forward_index | |||||
| // @property | |||||
| // def forward_sync(self): | |||||
| // """A control trigger node for synchronization in the forward loop. | |||||
| // One main use is to keep the push ops of a stack executed in the | |||||
| // iteration order. | |||||
| // """ | |||||
| // if self._forward_sync is None: | |||||
| // with ops.control_dependencies(None): | |||||
| // self._forward_sync = control_trigger(name="f_sync") | |||||
| // self._forward_sync._set_control_flow_context(self._forward_context) | |||||
| // self._forward_index.op._add_control_input(self._forward_sync) | |||||
| // return self._forward_sync | |||||
| // @property | |||||
| // def grad_context(self): | |||||
| // """The corresponding WhileContext for gradient.""" | |||||
| // return self._grad_context | |||||
| // @property | |||||
| // def grad_index(self): | |||||
| // """The loop index of backprop loop.""" | |||||
| // return self._grad_index | |||||
| // @property | |||||
| // def grad_sync(self): | |||||
| // """A control trigger node for synchronization in the grad loop. | |||||
| // One main use is to keep the pop ops of a stack executed in the | |||||
| // iteration order. | |||||
| // """ | |||||
| // if self._grad_sync is None: | |||||
| // with ops.control_dependencies(None): | |||||
| // self._grad_sync = control_trigger(name="b_sync") | |||||
| // self._grad_sync._set_control_flow_context(self._grad_context) | |||||
| // self._grad_index.op._add_control_input(self._grad_sync) | |||||
| // if self._grad_context.outer_context: | |||||
| // self._grad_context.outer_context.AddInnerOp(self._grad_sync) | |||||
| // return self._grad_sync | |||||
| // @property | |||||
| // def history_map(self): | |||||
| // """The map that records all the tensors needed for backprop.""" | |||||
| // return self._history_map | |||||
| // @property | |||||
| // def switch_map(self): | |||||
| // """The map that records all the Switch ops for the while loop.""" | |||||
| // return self._switch_map | |||||
| // @property | |||||
| // def unused_exits(self): | |||||
| // """The list of "unused" exits.""" | |||||
| // return self._unused_exits | |||||
| // @property | |||||
| // def deferred_exits(self): | |||||
| // """The list of "deferred" exits.""" | |||||
| // return self._deferred_exits | |||||
| // @property | |||||
| // def forward_loop_exits(self): | |||||
| // """The list of exits of the forward loop.""" | |||||
| // return self._forward_loop_exits | |||||
| // @property | |||||
| // def pending_exits_count(self): | |||||
| // """The number of exits we expect to see but haven't.""" | |||||
| // return self._pending_exits_count | |||||
| // @pending_exits_count.setter | |||||
| // def pending_exits_count(self, cnt): | |||||
| // """Set the pending count to cnt.""" | |||||
| // self._pending_exits_count = cnt | |||||
| /// <summary> | |||||
| /// Add an accumulator for each forward tensor that is needed in backprop. | |||||
| /// | |||||
| /// This is added to the forward loop at the first time when a tensor | |||||
| /// in the forward loop is used by backprop gradient computation loop. | |||||
| /// We create an accumulator that accumulates the value of tensor at each | |||||
| /// iteration. Called in the control flow context where gradients() is called. | |||||
| /// | |||||
| /// The pseudocode is: | |||||
| /// ``` | |||||
| /// acc = stack(); | |||||
| /// while (_pivot) { | |||||
| /// acc = stack_push(acc, value); | |||||
| /// } | |||||
| /// ``` | |||||
| /// | |||||
| /// We make sure that the stack push op in one iteration is executed before | |||||
| /// next iteration. This is achieved by adding a control edge from | |||||
| /// `forward_index.op.inputs[0].op` to the push op, and another control | |||||
| /// edge from the push op to either `forward_index.op` or `forward_sync`. | |||||
| /// </summary> | |||||
| /// <param name="value"> The source tensor in forward that is to be accumulated.</param> | |||||
| /// <param name="dead_branch"> True iff the tensor is on a dead branch of a cond.</param> | |||||
| /// <returns>The stack that contains the accumulated history of the tensor.</returns> | |||||
| public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false) | |||||
| { | |||||
| throw new NotImplementedException("AddForwardAccumulator"); | |||||
| // # curr_ctxt is the context that tf.gradients was called in. | |||||
| // with self._forward_index.graph.as_default(): | |||||
| // curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access | |||||
| // with ops.control_dependencies(None): | |||||
| // if curr_ctxt: | |||||
| // curr_ctxt.Enter() | |||||
| // with ops.colocate_with(value): | |||||
| // # We only need to pass maximum_iterations to the stack if | |||||
| // # we're inside an XLA context. | |||||
| // if not util.IsInXLAContext(value.op): | |||||
| // max_size = constant_op.constant(-1, dtypes.int32) | |||||
| // else: | |||||
| // max_size = GetMaxSizeFromNestedMaximumIterations( | |||||
| // value, self.forward_context) | |||||
| // acc = gen_data_flow_ops.stack_v2( | |||||
| // max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") | |||||
| // if curr_ctxt: | |||||
| // curr_ctxt.Exit() | |||||
| // # Make acc available in the forward context. | |||||
| // enter_acc = self.forward_context.AddValue(acc) | |||||
| // # Add the stack_push op in the context of value.op. | |||||
| // swap_enabled = self.forward_context.swap_memory | |||||
| // value_ctxt = util.GetOutputContext(value.op) | |||||
| // if value_ctxt == self.forward_context: | |||||
| // # value is not nested in the forward context. | |||||
| // self.forward_context.Enter() | |||||
| // push = gen_data_flow_ops.stack_push_v2( | |||||
| // enter_acc, value, swap_memory=swap_enabled) | |||||
| // self.forward_context.Exit() | |||||
| // # Protect stack push and order it before forward_index. | |||||
| // self.forward_index.op._add_control_input(push.op) | |||||
| // else: | |||||
| // # value is in a cond context within the forward context. | |||||
| // if not isinstance(value_ctxt, CondContext): | |||||
| // raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt) | |||||
| // if dead_branch: | |||||
| // # The special case for creating a zero tensor for a dead | |||||
| // # branch of a switch. See ControlFlowState.ZerosLike(). | |||||
| // value_ctxt.outer_context.Enter() | |||||
| // push = gen_data_flow_ops.stack_push_v2( | |||||
| // enter_acc, value, swap_memory=swap_enabled) | |||||
| // value_ctxt.outer_context.Exit() | |||||
| // push.op._set_control_flow_context(value_ctxt) | |||||
| // else: | |||||
| // value_ctxt.Enter() | |||||
| // push = gen_data_flow_ops.stack_push_v2( | |||||
| // enter_acc, value, swap_memory=swap_enabled) | |||||
| // value_ctxt.Exit() | |||||
| // # Protect stack push and order it before forward_sync. | |||||
| // self.forward_sync._add_control_input(push.op) | |||||
| // # Order stack push after the successor of forward_index | |||||
| // add_op = self.forward_index.op.inputs[0].op | |||||
| // push.op._add_control_input(add_op) | |||||
| // return acc | |||||
| } | |||||
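// Intended usage mirrors _MergeGrad in control_flow_grad.cs (earlier in this change):
// exit the gradient context, call AddForwardAccumulator(pred) so the predicate is recorded
// for every forward iteration, then AddBackpropAccumulatedValue pops those values so the
// backprop Switch sees the right predicate on each iteration.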
| // """Add the getter for an accumulated value in the grad context. | |||||
| // | |||||
| // This is added to the backprop loop. Called in the grad context to | |||||
| // get the value of an accumulated value. The stack pop op must be guarded | |||||
| // by the pred of the controlling cond. | |||||
| // | |||||
| // Args: | |||||
| // history_value: The history (a stack) of a value. | |||||
| // value: The value that is pushed onto the stack. | |||||
| // dead_branch: True iff the tensor is on a dead branch of a cond. | |||||
| // | |||||
| // Returns: | |||||
| // The current value (the top of the stack). | |||||
| // """ | |||||
public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch = false)
| { | |||||
| throw new NotImplementedException(); | |||||
| // history_ctxt = history_value.op._get_control_flow_context() | |||||
| // # Find the cond context that controls history_value if any. | |||||
| // cond_ctxt = None | |||||
| // value_ctxt = value.op._get_control_flow_context() | |||||
| // while value_ctxt and value_ctxt != history_ctxt: | |||||
| // if isinstance(value_ctxt, CondContext): | |||||
| // cond_ctxt = value_ctxt | |||||
| // break | |||||
| // value_ctxt = value_ctxt.outer_context | |||||
| // with ops.control_dependencies(None): | |||||
| // self.grad_context.Enter() | |||||
| // if cond_ctxt: | |||||
| // # Guard stack pop with a switch if it is controlled by a cond. | |||||
| // grad_state = self | |||||
| // pred = None | |||||
| // while pred is None and grad_state: | |||||
| // pred = grad_state.history_map.get(cond_ctxt.pred.name) | |||||
| // grad_state = grad_state.outer_grad_state | |||||
| // if pred is None: | |||||
| // pred = cond_ctxt.pred | |||||
| // branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch | |||||
| // history_value = _SwitchRefOrTensor(history_value, pred)[branch] | |||||
| // pop = gen_data_flow_ops.stack_pop_v2(history_value, | |||||
| // value.dtype.base_dtype) | |||||
| // pop.set_shape(value.get_shape()) | |||||
| // self.grad_context.Exit() | |||||
| // parallel_iterations = self.grad_context.parallel_iterations | |||||
| // if parallel_iterations > 1: | |||||
| // # All pops are ordered after pivot_for_body and before grad_sync. | |||||
| // self.grad_sync._add_control_input(pop.op) | |||||
| // return pop | |||||
| } | |||||
| // def GetRealValue(self, value): | |||||
| // """Get the real value of `value`. | |||||
| // If backprop "uses" a value produced by forward inference, an accumulator | |||||
| // is added in the forward loop to accumulate its values. We use the | |||||
| // accumulated value. This method must be called in the grad loop context. | |||||
| // `value` must be in forward and needed for backprop. | |||||
| // Args: | |||||
| // value: A tensor to be captured. | |||||
| // Returns: | |||||
| // The same tensor obtained from the saved history. | |||||
| // """ | |||||
| // assert value.op.type not in ["Variable", "VariableV2"] | |||||
| // real_value = self._history_map.get(value.name) | |||||
| // if real_value is None: | |||||
| // cur_value = value | |||||
| // cur_grad_state = self | |||||
| // while True: | |||||
| // enter_op = util.GetLoopConstantEnter(cur_value) | |||||
| // if enter_op: | |||||
| // # Special case: cur_value comes from a constant Enter node. | |||||
| // cur_value = enter_op.inputs[0] | |||||
| // cur_grad_state = cur_grad_state.outer_grad_state | |||||
| // if cur_grad_state is None: | |||||
| // # We are now outside all nested loops for this gradient(), | |||||
| // # so `value` is a loop invariant and there is no need to | |||||
| // # save the history of value. Just make cur_value to enter | |||||
| // # the right control flow context. | |||||
| // real_value = self._grad_context.AddValue(cur_value) | |||||
| // break | |||||
| // elif constant_op.is_constant(cur_value): | |||||
| // # If the value to be forwarded is a constant, clone the constant in | |||||
| // # the gradient loop rather than using a stack. | |||||
| // # TODO(phawkins): consider hoisting the constant out of the loop | |||||
| // # instead. | |||||
| // real_value = constant_op.constant( | |||||
| // tensor_util.constant_value(cur_value), dtype=cur_value.dtype) | |||||
| // break | |||||
| // else: | |||||
| // # Record the history of this value in forward_ctxt. | |||||
| // self._grad_context.Exit() | |||||
| // history_value = cur_grad_state.AddForwardAccumulator(cur_value) | |||||
| // self._grad_context.Enter() | |||||
| // break | |||||
| // if real_value is None: | |||||
| // # Add the stack pop op in the grad context. | |||||
| // real_value = cur_grad_state.AddBackpropAccumulatedValue( | |||||
| // history_value, cur_value) | |||||
| // if cur_grad_state != self: | |||||
| // real_value = self._grad_context.AddValue(real_value) | |||||
| // self._history_map[value.name] = real_value | |||||
| // return real_value | |||||
| } | |||||
| } | |||||
| @@ -4,13 +4,15 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public interface IControlFlowContext | |||||
| { | |||||
| void AddOp(Operation op); | |||||
| IControlFlowContext outer_context { get; } | |||||
| HashSet<string> values { get; } | |||||
| Tensor AddValue(Tensor val); | |||||
| void AddInnerOp(Operation resultOp); | |||||
| object to_proto(); | |||||
| } | |||||
// henon: the interface was more trouble than it was worth - no benefit, only extra
// maintenance cost - so the abstract ControlFlowContext base class is used directly instead.
| //public interface IControlFlowContext | |||||
| //{ | |||||
| // void AddOp(Operation op); | |||||
| // IControlFlowContext outer_context { get; } | |||||
| // HashSet<string> values { get; } | |||||
| // Tensor pivot { get; } | |||||
| // Tensor AddValue(Tensor val); | |||||
| // void AddInnerOp(Operation resultOp); | |||||
| // object to_proto(); | |||||
| //} | |||||
| } | } | ||||
| @@ -1,11 +1,26 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations.ControlFlows; | |||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| { | { | ||||
| public class WhileContext : ControlFlowContext | public class WhileContext : ControlFlowContext | ||||
| { | { | ||||
private bool _back_prop = true;
private GradLoopState _grad_state = null;
| public override WhileContext GetWhileContext() | |||||
| { | |||||
| return this; | |||||
| } | |||||
| public override GradLoopState grad_state => _grad_state; | |||||
| public override bool back_prop => _back_prop; | |||||
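// back_prop stays true by default (mirroring tf.while_loop's back_prop=True), and
// _grad_state is only non-null for a backprop loop: in the Python reference,
// GradLoopState passes itself as grad_state when it builds the backprop WhileContext.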
| public static WhileContext from_proto(object proto) | public static WhileContext from_proto(object proto) | ||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| @@ -7,7 +7,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class Operation | public partial class Operation | ||||
| { | { | ||||
| private IControlFlowContext _control_flow_context; | |||||
| private ControlFlowContext _control_flow_context; | |||||
| /// <summary> | /// <summary> | ||||
| /// Add this op to its control flow context. | /// Add this op to its control flow context. | ||||
| @@ -39,12 +39,12 @@ namespace Tensorflow | |||||
| _add_control_input(op); | _add_control_input(op); | ||||
| } | } | ||||
| public void _set_control_flow_context(IControlFlowContext ctx) | |||||
| public void _set_control_flow_context(ControlFlowContext ctx) | |||||
| { | { | ||||
| _control_flow_context = ctx; | _control_flow_context = ctx; | ||||
| } | } | ||||
| public IControlFlowContext _get_control_flow_context() | |||||
| public ControlFlowContext _get_control_flow_context() | |||||
| { | { | ||||
| return _control_flow_context; | return _control_flow_context; | ||||
| } | } | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using Tensorflow.Operations.ControlFlows; | |||||
| using util = Tensorflow.control_flow_util; | using util = Tensorflow.control_flow_util; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -93,9 +94,9 @@ namespace Tensorflow | |||||
| /// <param name="between_op_list"></param> | /// <param name="between_op_list"></param> | ||||
| /// <param name="between_ops"></param> | /// <param name="between_ops"></param> | ||||
| /// <param name="colocate_gradients_with_ops"></param> | /// <param name="colocate_gradients_with_ops"></param> | ||||
| public static object MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops) | |||||
| public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops) | |||||
| { | { | ||||
| object loop_state = null; | |||||
| ControlFlowState loop_state = null; | |||||
| foreach (var op in between_op_list) | foreach (var op in between_op_list) | ||||
| { | { | ||||
| @@ -103,7 +104,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| if(loop_state == null) | if(loop_state == null) | ||||
| { | { | ||||
| // loop_state = ControlFlowState(); | |||||
| loop_state = new ControlFlowState(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -207,7 +208,7 @@ namespace Tensorflow | |||||
| /// `(output_false, output_true)`: If `pred` is true, data will be forwarded to | /// `(output_false, output_true)`: If `pred` is true, data will be forwarded to | ||||
| /// `output_true`, otherwise it goes to `output_false`. | /// `output_true`, otherwise it goes to `output_false`. | ||||
| /// </returns> | /// </returns> | ||||
| public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | |||||
| public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | |||||
| { | { | ||||
| data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); | data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); | ||||
| // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below | // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below | ||||
| @@ -298,7 +299,9 @@ namespace Tensorflow | |||||
| */ | */ | ||||
| // Add the Switch to the graph. | // Add the Switch to the graph. | ||||
| var (p_2, p_1) = @switch(pred, pred); | |||||
var switch_result = @switch(pred, pred);
var p_2 = switch_result[0];
var p_1 = switch_result[1];
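// Switch returns (output_false, output_true): index 0 feeds the false pivot
// (switch_f) and index 1 the true pivot (switch_t).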
| var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | ||||
| var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | ||||
| pred = array_ops.identity(pred, name: "pred_id"); | pred = array_ops.identity(pred, name: "pred_id"); | ||||
| @@ -379,7 +382,9 @@ namespace Tensorflow | |||||
| return with(ops.name_scope(name, "cond", new { pred }), delegate | return with(ops.name_scope(name, "cond", new { pred }), delegate | ||||
| { | { | ||||
| // Add the Switch to the graph. | // Add the Switch to the graph. | ||||
| var (p_2, p_1) = @switch(pred, pred); | |||||
| var switch_result = @switch(pred, pred); | |||||
| var p_2 = switch_result[0]; | |||||
| var p_1 = switch_result[1]; | |||||
| var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | ||||
| var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | ||||
| pred = array_ops.identity(pred, name: "pred_id"); | pred = array_ops.identity(pred, name: "pred_id"); | ||||
| @@ -460,7 +465,7 @@ namespace Tensorflow | |||||
| /// <param name="pred"></param> | /// <param name="pred"></param> | ||||
| /// <param name="dtype"></param> | /// <param name="dtype"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| public static (Tensor, Tensor) @switch(Tensor data, | |||||
| public static Tensor[] @switch(Tensor data, | |||||
| Tensor pred, | Tensor pred, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| string name = null) | string name = null) | ||||
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Return the control flow context for the output of an op. | /// Return the control flow context for the output of an op. | ||||
| /// </summary> | /// </summary> | ||||
| public static IControlFlowContext GetOutputContext(Operation op) | |||||
| public static ControlFlowContext GetOutputContext(Operation op) | |||||
| { | { | ||||
| var ctxt = op._get_control_flow_context(); | var ctxt = op._get_control_flow_context(); | ||||
| // Exit nodes usually have a control flow context, except in the case where the | // Exit nodes usually have a control flow context, except in the case where the | ||||
| @@ -33,14 +33,14 @@ namespace Tensorflow | |||||
| /// output_false: A `Tensor`. Has the same type as `data`. | /// output_false: A `Tensor`. Has the same type as `data`. | ||||
| /// output_true: A `Tensor`. Has the same type as `data`. | /// output_true: A `Tensor`. Has the same type as `data`. | ||||
| /// </returns> | /// </returns> | ||||
| public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null) | |||||
| public static Tensor[] @switch(Tensor data, Tensor pred, string name = null) | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); | var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); | ||||
| var _inputs_flat = _op.inputs; | var _inputs_flat = _op.inputs; | ||||
| var _attrs = ("T", _op.get_attr("T")); | var _attrs = ("T", _op.get_attr("T")); | ||||
| // TODO: missing original code | // TODO: missing original code | ||||
| //_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name); | //_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name); | ||||
| return (_op.outputs[0], _op.outputs[1]); | |||||
return new[] { _op.outputs[0], _op.outputs[1] };
| } | } | ||||
| public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) | public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) | ||||