diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs
index afc87d45..de61e52b 100644
--- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs
@@ -1,12 +1,79 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
using Tensorflow.Operations;
namespace Tensorflow.Gradients
{
+ /// <summary>
+ /// Gradients for operators defined in control_flow_ops.py.cs
+ /// </summary>
public class control_flow_grad
{
+ /// <summary>
+ /// Gradients for a Switch op are calculated using a Merge op.
+ ///
+ /// If the switch is a loop switch, it will be visited twice. We create
+ /// the merge on the first visit, and update the other input of the merge
+ /// on the second visit. A next_iteration is also added on second visit.
+ /// </summary>
+ /// <returns></returns>
+ public Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
+ {
+ throw new NotImplementedException("_SwitchGrad");
+ //graph = ops.get_default_graph()
+ //# pylint: disable=protected-access
+ //op_ctxt = op._get_control_flow_context()
+ //grad_ctxt = graph._get_control_flow_context()
+ //# pylint: enable=protected-access
+ //if isinstance(op_ctxt, WhileContext):
+ // merge_grad = grad_ctxt.grad_state.switch_map.get(op)
+ // if merge_grad is not None:
+ // # This is the second time this Switch is visited. It comes from
+ // # the non-exit branch of the Switch, so update the second input
+ // # to the Merge.
+ // # TODO(yuanbyu): Perform shape inference with this new input.
+ // if grad[1] is not None:
+ // # pylint: disable=protected-access
+ // control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1],
+ // enforce_shape_invariant=False)
+ // # pylint: enable=protected-access
+ // return None, None
+ // elif grad[0] is not None:
+ // # This is the first time this Switch is visited. It comes from
+ // # the Exit branch, which is grad[0]. grad[1] is empty at this point.
+ // # Use grad[0] for both inputs to merge for now, but update the second
+ // # input of merge when we see this Switch the second time.
+ // merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
+ // grad_ctxt.grad_state.switch_map[op] = merge_grad
+ // return merge_grad, None
+ // else:
+ // # This is the first time this Switch is visited. It comes from the
+ // # Identity branch. Such a Switch has `None` gradient for the Exit branch,
+ // # meaning the output is not differentiable.
+ // return None, None
+ //elif isinstance(op_ctxt, CondContext):
+ // zero_grad = grad[1 - op_ctxt.branch]
+ // # At this point, we have created zero_grad guarded by the right switch.
+ // # Unfortunately, we may still get None here for not trainable data types.
+ // if zero_grad is None:
+ // # For resource variables we get None always on the other branch, so bypass
+ // # this.
+ // if op.inputs[0].dtype == dtypes.resource:
+ // return merge(
+ // [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None
+ // return None, None
+ // return merge(grad, name="cond_grad")[0], None
+ //else:
+ // false_grad = switch(grad[0], op.inputs[1])[0]
+ // true_grad = switch(grad[1], op.inputs[1])[1]
+ // return merge([false_grad, true_grad])[0], None
+ }
+
+ /// <summary>
+ /// Gradients for a Merge op are calculated using a Switch op.
+ /// </summary>
public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -14,10 +81,164 @@ namespace Tensorflow.Gradients
var input_op = op.inputs[0].op;
var graph = ops.get_default_graph();
var op_ctxt = control_flow_util.GetOutputContext(input_op);
- var pred = (op_ctxt as CondContext).pred;
+ var grad_ctxt = graph._get_control_flow_context();
+ switch (op_ctxt)
+ {
+ case WhileContext cwhile:
+ {
+ return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot);
+ }
+ case CondContext ccond:
+ {
+ var pred = ccond.pred;
+ if (grad_ctxt != null && grad_ctxt.grad_state != null)
+ {
+ //# This Merge node is part of a cond within a loop.
+ //# The backprop needs to have the value of this predicate for every
+ //# iteration. So we must have its values accumulated in the forward, and
+ //# use the accumulated values as the predicate for this backprop switch.
+ var grad_state = grad_ctxt.grad_state;
+ var real_pred = grad_state.history_map[pred.name] as Tensor;
+ if (real_pred == null)
+ {
+ //# Remember the value of pred for every iteration.
+ grad_ctxt = grad_state.grad_context;
+ grad_ctxt.Exit();
+ var history_pred = grad_state.AddForwardAccumulator(pred);
+ grad_ctxt.Enter();
+
+ //# Add the stack pop op. If pred.op is in a (outer) CondContext,
+ //# the stack pop will be guarded with a switch.
+ real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred);
+ grad_state.history_map[pred.name] = real_pred;
+ }
+ pred = real_pred;
+ }
+ var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad");
+ return results;
+ }
+ default:
+ {
+ var num_inputs = op.inputs.Length;
+ var cond = new Tensor[num_inputs];
+ for (int i = 0; i < num_inputs; i++)
+ cond[i] = math_ops.equal(op.outputs[1], i);
+ var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray();
+ return result;
+ }
+ }
- var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad");
- return new Tensor[] { results.Item1, results.Item2 };
}
- }
+
+ public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
+ {
+ return _MergeGrad(op, grads);
+ }
+
+ /// <summary>
+ /// Gradients for an exit op are calculated using an Enter op.
+ /// </summary>
+ public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
+ {
+ throw new NotImplementedException("_ExitGrad");
+ // graph = ops.get_default_graph()
+ //# pylint: disable=protected-access
+ // op_ctxt = op._get_control_flow_context()
+ // grad_ctxt = graph._get_control_flow_context()
+ // # pylint: enable=protected-access
+ // if not grad_ctxt.back_prop:
+ // # The flag `back_prop` is set by users to suppress gradient
+ // # computation for this loop. If the attribute `back_prop` is false,
+ // # no gradient computation.
+ // return None
+
+ // if op_ctxt.grad_state:
+ // raise TypeError("Second-order gradient for while loops not supported.")
+
+ // if isinstance(grad, ops.Tensor) :
+ // grad_ctxt.AddName(grad.name)
+ // else:
+ // if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
+ // raise TypeError("Type %s not supported" % type(grad))
+ // grad_ctxt.AddName(grad.values.name)
+ // grad_ctxt.AddName(grad.indices.name)
+ // dense_shape = grad.dense_shape
+ // if dense_shape is not None:
+ // grad_ctxt.AddName(dense_shape.name)
+ // grad_ctxt.Enter()
+ // # pylint: disable=protected-access
+ // result = control_flow_ops._Enter(
+ // grad, grad_ctxt.name, is_constant=False,
+ // parallel_iterations=grad_ctxt.parallel_iterations,
+ // name="b_exit")
+ // # pylint: enable=protected-access
+ // grad_ctxt.loop_enters.append(result)
+ // grad_ctxt.Exit()
+ // return result
+ }
+
+ /// <summary>
+ /// A forward next_iteration is translated into a backprop identity.
+ ///
+ /// Note that the backprop next_iteration is added in switch grad.
+ /// </summary>
+ public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad)
+ {
+ return (_, grad);
+ }
+
+ public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad)
+ {
+ return (_, grad);
+ }
+
+ /// <summary>
+ /// Gradients for an Enter op are calculated using an Exit op.
+ ///
+ /// For loop variables, grad is the gradient so just add an exit.
+ /// For loop invariants, we need to add an accumulator loop.
+ /// </summary>
+ public (object, Tensor[]) _EnterGrad(Operation op, Tensor[] grad)
+ {
+ throw new NotImplementedException("_EnterGrad");
+ // graph = ops.get_default_graph()
+ //# pylint: disable=protected-access
+ // grad_ctxt = graph._get_control_flow_context()
+ // # pylint: enable=protected-access
+ // if not grad_ctxt.back_prop:
+ // # Skip gradient computation, if the attribute `back_prop` is false.
+ // return grad
+ // if grad_ctxt.grad_state is None:
+ // # Pass the gradient through if we are not in a gradient while context.
+ // return grad
+ // if op.get_attr("is_constant"):
+ // # Add a gradient accumulator for each loop invariant.
+ // if isinstance(grad, ops.Tensor) :
+ // result = grad_ctxt.AddBackpropAccumulator(op, grad)
+ // elif isinstance(grad, ops.IndexedSlices) :
+ // result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
+ // else:
+ // # TODO(yuanbyu, lukasr): Add support for SparseTensor.
+ // raise TypeError("Type %s not supported" % type(grad))
+ // else:
+ // result = exit(grad)
+ // grad_ctxt.loop_exits.append(result)
+ // grad_ctxt.ExitResult([result])
+ // return result
+ }
+
+ public (object, Tensor[]) _RefEnterGrad(Operation op, Tensor[] grad)
+ {
+ return _EnterGrad(op, grad);
+ }
+
+ /// <summary>
+ /// Stop backprop for the predicate of a while loop.
+ /// </summary>
+ public object _LoopCondGrad(object _)
+ {
+ return null;
+ }
+
+ }
}
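For orientation between the two gradient functions above: the gradient of a Merge is a Switch on the predicate, and the gradient of a Switch is a Merge of the branch gradients. The sketch below is one possible C# rendering of the cond-only branch of `_SwitchGrad`, mirroring the commented Python; it assumes a `control_flow_ops.merge(Tensor[], string)` helper returning `Tensor[]` and a public `branch` property on `CondContext`, neither of which is confirmed by this patch, and it omits the while-loop and resource-variable cases.

```cs
// Hedged sketch only; not part of the patch. Assumes control_flow_ops.merge
// and CondContext.branch exist as described in the lead-in.
private Tensor[] _SwitchGradCondSketch(Operation op, Tensor[] grads)
{
    var op_ctxt = op._get_control_flow_context() as CondContext;
    if (op_ctxt == null)
        throw new NotImplementedException("only the cond case is sketched here");

    // Gradient arriving from the branch that was not taken; null for
    // non-trainable dtypes.
    var zero_grad = grads[1 - op_ctxt.branch];
    if (zero_grad == null)
        return new Tensor[] { null, null };

    // Merge the gradients of both branches; the predicate input of the
    // Switch gets no gradient.
    return new Tensor[] { control_flow_ops.merge(grads, name: "cond_grad")[0], null };
}
```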
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
index 42cf1a17..fda9ff01 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
@@ -3,13 +3,14 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
+using Tensorflow.Operations;
namespace Tensorflow
{
public partial class Graph
{
// Current control flow context. It could be either CondContext or WhileContext
- public IControlFlowContext _control_flow_context;
+ public ControlFlowContext _control_flow_context;
// represents the nested with(...) statements
public List<_ControlDependenciesController> _control_dependencies_stack { get; set; } = new List<_ControlDependenciesController>();
@@ -97,7 +98,7 @@ namespace Tensorflow
/// Returns the current control flow context.
///
/// A context object.
- public IControlFlowContext _get_control_flow_context()
+ public ControlFlowContext _get_control_flow_context()
{
return _control_flow_context;
}
@@ -106,7 +107,7 @@ namespace Tensorflow
/// Sets the current control flow context.
///
/// a context object.
- public void _set_control_flow_context(IControlFlowContext ctx)
+ public void _set_control_flow_context(ControlFlowContext ctx)
{
_control_flow_context = ctx;
}
diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
index 36832b35..047624d5 100644
--- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
+++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;
+using Tensorflow.Operations;
namespace Tensorflow
{
@@ -15,7 +16,7 @@ namespace Tensorflow
private List _seen_nodes;
private List<_ControlDependenciesController> _old_stack;
private bool _new_stack;
- private IControlFlowContext _old_control_flow_context;
+ private ControlFlowContext _old_control_flow_context;
public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray();
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
index 47908e05..254df0cf 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
+using Tensorflow.Operations.ControlFlows;
namespace Tensorflow.Operations
{
@@ -107,8 +108,8 @@ namespace Tensorflow.Operations
with(ops.control_dependencies(null), ctrl =>
{
- var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred);
- result = new[] { r0, r1 }[_branch];
+ var results = control_flow_ops._SwitchRefOrTensor(result, _pred);
+ result = results[_branch];
if (_outer_context != null)
_outer_context.AddInnerOp(result.op);
});
@@ -118,7 +119,7 @@ namespace Tensorflow.Operations
// Mark Switch output as seen by this context and any outer contexts,
// just like what we do for normal op outputs in _AddOpInternal() below.
- IControlFlowContext ctxt = this;
+ ControlFlowContext ctxt = this;
while (ctxt != null)
{
ctxt.values.Add(result.name);
@@ -223,8 +224,8 @@ namespace Tensorflow.Operations
_values.Add(real_val.name);
_external_values[real_val.name] = real_val;
}
- var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred);
- real_val = new[] {t0, t1}[_branch];
+ var results = control_flow_ops._SwitchRefOrTensor(real_val, _pred);
+ real_val = results[_branch];
_external_values[val.name] = real_val;
}
else
@@ -238,8 +239,8 @@ namespace Tensorflow.Operations
return real_val;
}
- protected override void _AddOpInternal(Operation op)
- {
+ protected override void _AddOpInternal(Operation op)
+ {
if (op.inputs.Length == 0)
{
//If we're in a while loop, remove any control inputs from outside the
@@ -282,11 +283,11 @@ namespace Tensorflow.Operations
// TODO: implement below code dependencies
//if (op.graph._is_function(op.type) || op.type == "SymbolicGradient")
// op._add_control_input(_pivot.op);
- }
-
- // Mark op's outputs as seen by this context and any outer contexts.
+ }
+
+ // Mark op's outputs as seen by this context and any outer contexts.
var output_names = op.outputs.Select(x => x.name).ToArray();
- IControlFlowContext ctxt = this;
+ ControlFlowContext ctxt = this;
while (ctxt != null)
{
foreach (var name in output_names)
@@ -298,9 +299,31 @@ namespace Tensorflow.Operations
op.graph.prevent_fetching(op);
if (_outer_context != null)
- _outer_context.AddInnerOp(op);
- }
-
+ _outer_context.AddInnerOp(op);
+ }
+
+ public override GradLoopState grad_state
+ {
+ get
+ {
+ var whc = GetWhileContext();
+ if (whc != null)
+ return whc.grad_state;
+ return null;
+ }
+ }
+
+ public override bool back_prop
+ {
+ get
+ {
+ var whc = GetWhileContext();
+ if (whc != null)
+ return whc.back_prop;
+ return false;
+ }
+ }
+
public CondContextDef to_proto(string export_scope)
{
throw new NotImplementedException();
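Both new CondContext overrides simply delegate to the enclosing while context. A behavior-preserving, more compact form (optional style note, using null-conditional operators) would be:

```cs
// Equivalent to the two overrides added above: delegate to the enclosing
// WhileContext if there is one, otherwise fall back to null / false.
public override GradLoopState grad_state => GetWhileContext()?.grad_state;
public override bool back_prop => GetWhileContext()?.back_prop ?? false;
```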
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
index 56b38846..48a519db 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
+using Tensorflow.Operations.ControlFlows;
namespace Tensorflow.Operations
{
@@ -22,21 +23,25 @@ namespace Tensorflow.Operations
/// 4. A ControlFlowContext has _context_stack.
/// Pushed and popped by ctxt.Enter() and ctxt.Exit()
///
- public abstract class ControlFlowContext : Python, IPython, IControlFlowContext
+ public abstract class ControlFlowContext : Python, IPython
{
///
/// The predicate tensor in this branch
///
protected Tensor _pivot;
+ public Tensor pivot
+ {
+ get => _pivot;
+ }
- protected Stack<IControlFlowContext> _context_stack;
- protected IControlFlowContext _outer_context;
+ protected Stack<ControlFlowContext> _context_stack;
+ protected ControlFlowContext _outer_context;
protected Dictionary _external_values;
public ControlFlowContext()
{
- _context_stack = new Stack<IControlFlowContext>();
+ _context_stack = new Stack<ControlFlowContext>();
}
public string name { get => _name; }
@@ -111,8 +116,13 @@ namespace Tensorflow.Operations
_AddOpInternal(op);
}
- public IControlFlowContext outer_context { get { return _outer_context; } }
+ public ControlFlowContext outer_context { get { return _outer_context; } }
public HashSet<string> values => _values;
+
+ public virtual GradLoopState grad_state => throw new NotImplementedException("abstract method");
+
+ public virtual bool back_prop => throw new NotImplementedException("abstract method");
+
public virtual Tensor AddValue(Tensor val)
{
// to be overridden
@@ -147,7 +157,7 @@ namespace Tensorflow.Operations
///
/// Returns true if `maybe_containing_ctxt` is or contains `ctxt`.
///
- public static bool IsContainingContext(IControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt)
+ public static bool IsContainingContext(ControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt)
{
while (ctxt != maybe_containing_ctxt)
{
@@ -164,6 +174,16 @@ namespace Tensorflow.Operations
var internal_control_inputs = op.control_inputs;
}
+ /// <summary>
+ /// Return the while context containing this context.
+ /// </summary>
+ public virtual WhileContext GetWhileContext()
+ {
+ if (_outer_context != null)
+ return _outer_context.GetWhileContext();
+ return null;
+ }
+
public object to_proto()
{
throw new NotImplementedException();
@@ -173,5 +193,6 @@ namespace Tensorflow.Operations
public void Dispose()
{
}
+
}
}
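The virtual `GetWhileContext` added above, together with the `WhileContext` override later in this patch, is what lets a context nested inside a loop find that loop by walking `_outer_context`. The stand-alone snippet below illustrates only that lookup pattern with stripped-down stand-in classes; it is not TensorFlow.NET code.

```cs
using System;

// Stand-in classes that mimic only the GetWhileContext() recursion.
abstract class Ctx
{
    protected Ctx _outer_context;
    protected Ctx(Ctx outer) { _outer_context = outer; }

    // Default: defer to the enclosing context; null at the top level.
    public virtual WhileCtx GetWhileContext() => _outer_context?.GetWhileContext();
}

class WhileCtx : Ctx
{
    public WhileCtx(Ctx outer) : base(outer) { }
    public override WhileCtx GetWhileContext() => this; // a loop is its own while context
}

class CondCtx : Ctx
{
    public CondCtx(Ctx outer) : base(outer) { }
}

static class Demo
{
    static void Main()
    {
        var loop = new WhileCtx(null);
        var cond = new CondCtx(loop);                                    // cond nested in the loop
        Console.WriteLine(cond.GetWhileContext() == loop);               // True
        Console.WriteLine(new CondCtx(null).GetWhileContext() == null);  // True
    }
}
```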
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs
new file mode 100644
index 00000000..c87ba1c6
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs
@@ -0,0 +1,277 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations.ControlFlows
+{
+ /// <summary>
+ /// Maintain the mapping from the loops to their grad states.
+ /// </summary>
+ public class ControlFlowState
+ {
+ //class ControlFlowState(object):
+ // """Maintain the mapping from the loops to their grad states."""
+
+ // def __init__(self):
+ // self._map = {} # maps forward loop context to GradLoopState
+
+ // def GetGradState(self, op, before):
+ // """Return the grad state for this op if it's in a forward loop context."""
+ // if before and util.IsLoopExit(op):
+ // forward_ctxt = op._get_control_flow_context()
+ // forward_ctxt = forward_ctxt.outer_context
+ // if forward_ctxt:
+ // forward_ctxt = forward_ctxt.GetWhileContext()
+ // else:
+ // forward_ctxt = _GetWhileContext(op)
+ // if forward_ctxt:
+ // return self._map.get(forward_ctxt)
+ // return None
+
+ // def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
+ // """Process all the "unused" loop exits.
+
+ // The "unused" exits of the loops are added to `unused_exits`. An exit is
+ // unused if its pending_count is 0. If there is an exit with real gradient,
+ // all these deferred exits will enter the backprop loop with zero gradient.
+ // Otherwise, they will enter the backprop loop with None. As an example,
+ // people often write:
+
+ // ```python
+ // v1, _ = tf.while_loop(p, b, [x1, x2])
+ // result = gradients(v1, x1)
+ // ```
+
+ // The exit node for x2 is not included by the betweenness analysis. But we
+ // need to backprop x2 if x2 is involved in computing v1.
+
+ // Args:
+ // pending_count: The number of backprop inputs for every op.
+ // to_ops_set: The set of ops for ys in gradients(ys, xs)
+
+ // Returns:
+ // The set of unused loop exits that we know at this point we need
+ // to backprop.
+ // """
+ // loop_exits = []
+ // for grad_state in self._map.values():
+ // for y in grad_state.forward_loop_exits:
+ // if pending_count[y.op] == 0:
+ // grad_state.pending_exits_count -= 1
+ // if y.op not in to_ops_set:
+ // grad_state.unused_exits.append(y)
+ // if grad_state.pending_exits_count == 0:
+ // loop_exits.extend(grad_state.unused_exits)
+ // # Need to include Enters in backprop for higher-order gradients.
+ // for y in grad_state.forward_context.loop_enters:
+ // if pending_count[y.op] == 0:
+ // pending_count[y.op] = 1
+ // return loop_exits
+
+ // def EnterGradWhileContext(self, op, before):
+ // """Enter the WhileContext for gradient computation."""
+ // grad_state = self.GetGradState(op, before)
+ // if grad_state:
+ // grad_state.grad_context.Enter()
+
+ // def ExitGradWhileContext(self, op, before):
+ // """Exit the WhileContext for gradient computation."""
+ // grad_state = self.GetGradState(op, before)
+ // if grad_state:
+ // grad_state.grad_context.Exit()
+
+ // def AddWhileContext(self, op, between_op_list, between_ops):
+ // """Add the grad state for the while loop that op belongs to.
+
+ // Note that op is an Exit, and this method must be called in
+ // the control flow context where gradients() is called.
+
+ // Note that this method modifies `between_op_list` and `between_ops`.
+ // """
+ // forward_ctxt = _GetWhileContext(op)
+ // grad_state = self._map.get(forward_ctxt)
+ // if grad_state is None:
+ // # This is a new while loop so create a grad state for it.
+ // outer_forward_ctxt = forward_ctxt.outer_context
+ // if outer_forward_ctxt:
+ // outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
+ // outer_grad_state = None
+ // if outer_forward_ctxt:
+ // outer_grad_state = self._map.get(outer_forward_ctxt)
+ // grad_state = GradLoopState(forward_ctxt, outer_grad_state)
+ // self._map[forward_ctxt] = grad_state
+
+ // # We need to include all exits of a loop for backprop.
+ // for loop_exit in grad_state.forward_loop_exits:
+ // if loop_exit.op not in between_ops:
+ // between_ops.add(loop_exit.op)
+ // between_op_list.append(loop_exit.op)
+
+ // def ZerosLikeForExit(self, val):
+ // """Create zeros_like gradient for a loop exit.
+
+ // If the result of a loop variable is not used but is involved in
+ // computing the result of some needed loop variable, we create a
+ // zero-valued tensor that is fed as gradient for the Exit node of that
+ // loop variable. Note that val.op is an Exit, and this method must be
+ // called in the control flow context where gradients() is called.
+
+ // Args:
+ // val: The output tensor of an Exit op.
+
+ // Returns:
+ // A zero tensor of the same shape of val.
+ // """
+ // val_shape = val.get_shape()
+ // forward_ctxt = val.op._get_control_flow_context()
+ // outer_forward_ctxt = forward_ctxt.outer_context
+ // if outer_forward_ctxt:
+ // outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
+ // outer_grad_state = None
+ // if outer_forward_ctxt:
+ // outer_grad_state = self._map.get(outer_forward_ctxt)
+ // if outer_grad_state:
+ // # This is a nested loop.
+ // if val_shape.is_fully_defined():
+ // # If the shape is known statically, just create a zero tensor
+ // # with the right shape in the right context.
+ // outer_grad_state.grad_context.Enter()
+ // result = array_ops.zeros(val_shape.dims, val.dtype)
+ // outer_grad_state.grad_context.Exit()
+ // else:
+ // # Only the shape of value is needed for backprop.
+ // forward_ctxt.outer_context.Enter()
+ // shape = array_ops.shape_internal(val, optimize=False)
+ // forward_ctxt.outer_context.Exit()
+ // # Save the shape to a stack.
+ // history_shape = outer_grad_state.AddForwardAccumulator(shape)
+ // # Get the shape back from the stack.
+ // outer_grad_ctxt = outer_grad_state.grad_context
+ // outer_grad_ctxt.Enter()
+ // real_shape = outer_grad_state.AddBackpropAccumulatedValue(
+ // history_shape, shape)
+ // result = array_ops.zeros(real_shape, val.dtype)
+ // outer_grad_ctxt.Exit()
+ // else:
+ // # This is not a nested loop.
+ // if val_shape.is_fully_defined():
+ // # If the shape is known statically, just create a zero tensor
+ // # with the right shape.
+ // result = array_ops.zeros(val_shape.dims, val.dtype)
+ // else:
+ // result = array_ops.zeros_like(val, optimize=False)
+ // return result
+
+ // def ZerosLike(self, op, index):
+ // """Create zeros_like for the specified output of an op.
+
+ // If op is in a while loop that is part of gradients(), this method
+ // must be called in its grad loop context.
+
+ // Args:
+ // op: A tensorflow operation.
+ // index: the index for a specific output of the op.
+
+ // Returns:
+ // A zero tensor of the same shape of op.outputs[index].
+ // """
+ // if util.IsLoopSwitch(op):
+ // return None
+ // if op.graph._building_function: # pylint: disable=protected-access
+ // # The optimization here is tricky to apply to functions
+ // return array_ops.zeros_like(op.outputs[index])
+ // dead_branch = util.IsSwitch(op)
+ // forward_ctxt = _GetWhileContext(op)
+ // grad_state = self._map.get(forward_ctxt)
+ // if grad_state is None:
+ // # op is not in a while loop that is part of gradients().
+ // return ZerosLikeOutsideLoop(op, index)
+ // op_ctxt = op._get_control_flow_context()
+ // val = ops.convert_to_tensor(op.outputs[index], name="tensor")
+ // shape = val.get_shape()
+ // if shape.is_fully_defined():
+ // # If the shape is known statically, just create a zero tensor with
+ // # the right shape in the grad loop context.
+ // result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
+ // if dead_branch:
+ // # op is a cond switch. Guard the zero tensor with a switch.
+ // pred = grad_state.history_map.get(op_ctxt.pred.name)
+ // branch = op_ctxt.branch
+ // result = _SwitchRefOrTensor(result, pred)[1 - branch]
+ // else:
+ // # Unknown shape so keep a history of the shape at runtime.
+ // if dead_branch:
+ // # Need to add a special switch to guard the value.
+ // pred = op_ctxt.pred
+ // branch = op_ctxt.branch
+ // op_ctxt.outer_context.Enter()
+ // val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
+ // zeros_shape = array_ops.shape_internal(val, optimize=False)
+ // op_ctxt.outer_context.Exit()
+ // val.op._set_control_flow_context(op_ctxt)
+ // zeros_shape.op._set_control_flow_context(op_ctxt)
+ // else:
+ // op_ctxt.Enter()
+ // zeros_shape = array_ops.shape_internal(val, optimize=False)
+ // op_ctxt.Exit()
+
+ // # Add forward accumulator for shape.
+ // grad_state.grad_context.Exit()
+ // history_zeros_shape = grad_state.AddForwardAccumulator(
+ // zeros_shape, dead_branch=dead_branch)
+ // grad_state.grad_context.Enter()
+
+ // # Create a zero tensor with the right shape.
+ // shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
+ // zeros_shape, dead_branch)
+ // result = array_ops.zeros(shape, val.dtype)
+ // return result
+
+ // def PostProcessing(self):
+ // """Perform postprocessing at the end of gradients().
+
+ // We have created the gradient graph at this point. So this function
+ // can be used to perform any postprocessing on the gradient graph.
+ // We currently perform the following postprocessing:
+ // 1. Patch the gradient graph if the output of a loop variable
+ // doesn't depend on its input.
+ // """
+ // for _, grad_state in self._map.items():
+ // for _, b_merge in grad_state.switch_map.items():
+ // if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
+ // # The value of this loop variable at iteration i+1 doesn't
+ // # depend on its value at iteration i. So use zeros as the
+ // # gradients for all iterations > 0.
+ // dtype = b_merge.op.inputs[0].dtype
+ // shape = b_merge.op.inputs[0].get_shape()
+ // # pylint: disable=protected-access
+ // if shape.is_fully_defined():
+ // grad_state.grad_context.Enter()
+ // # Create a zeros and use it for iterations > 0.
+ // grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
+ // next_grad_val = _NextIteration(grad_val)
+ // grad_state.grad_context.Exit()
+ // else:
+ // # Create a zeros in the outer grad context.
+ // outer_grad_ctxt = grad_state.grad_context.outer_context
+ // if outer_grad_ctxt:
+ // outer_grad_ctxt.Enter()
+ // enter_grad_op = b_merge.op.inputs[0].op
+ // enter_grad = enter_grad_op.inputs[0]
+ // grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
+ // grad_val = array_ops.zeros(grad_shape)
+ // if outer_grad_ctxt:
+ // outer_grad_ctxt.Exit()
+ // # Use the zeros for iterations > 0.
+ // grad_state.grad_context.Enter()
+ // next_grad_val = _NextIteration(grad_val)
+ // grad_state.grad_context.Exit()
+ // b_merge.op._update_input(1, next_grad_val)
+ // # pylint: enable=protected-access
+
+ }
+
+
+
+
+}
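The class body so far is a stub; the commented Python above shows the intended behavior. As one concrete data point, `GetGradState` could take roughly the shape below. This is a hedged sketch, not part of the patch: the `Dictionary<WhileContext, GradLoopState>` map type and the `util.IsLoopExit` / `util.GetWhileContext` helpers are assumptions drawn from the Python comments.

```cs
// Hedged sketch of possible ControlFlowState members; helper names are assumed.
private readonly Dictionary<WhileContext, GradLoopState> _map
    = new Dictionary<WhileContext, GradLoopState>();

public GradLoopState GetGradState(Operation op, bool before)
{
    WhileContext forward_ctxt;
    if (before && util.IsLoopExit(op)) // assumed helper
    {
        // For an Exit node, the relevant loop is the one around the Exit's own context.
        var ctxt = op._get_control_flow_context()?.outer_context;
        forward_ctxt = ctxt?.GetWhileContext();
    }
    else
    {
        forward_ctxt = util.GetWhileContext(op); // assumed helper
    }

    if (forward_ctxt != null && _map.TryGetValue(forward_ctxt, out var grad_state))
        return grad_state;
    return null;
}
```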
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
new file mode 100644
index 00000000..e8fda1a0
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
@@ -0,0 +1,398 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations.ControlFlows
+{
+ public class GradLoopState
+ {
+
+ //class GradLoopState(object):
+ // """The state used for constructing the gradient graph for a while loop.
+
+ // We create a GradLoopState for each while loop in forward and its
+ // corresponding while loop in backprop. This gives us access to both
+ // the forward and the backprop WhileContexts.
+
+ // During the construction of gradient graph, any time when we detect
+ // a forward value that is needed for backprop, we create a history
+ // accumulator and add it to `history_map`. Any time when we backprop
+ // a loop switch op (in _SwitchGrad), we add the grad merge op in
+ // `switch_map`.
+ // """
+
+ // def __init__(self, forward_ctxt, outer_grad_state):
+ // # The grad loop state for the outer while loop.
+ // self._outer_grad_state = None
+
+ // # The while loop context for forward.
+ // self._forward_context = None
+
+ // # The loop counter added by AddForwardLoopCounter. It is the value
+ // # of the loop counter for the next iteration.
+ // self._forward_index = None
+
+ // # A sync op for forward.
+ // self._forward_sync = None
+
+ // # The while loop context for backprop.
+ private WhileContext _grad_context = null;
+
+ public WhileContext grad_context => _grad_context;
+
+ // # The loop counter added by AddBackpropLoopCounter. It is the value
+ // # of the loop counter for the current iteration.
+ // self._grad_index = None
+
+ // # A sync op for backprop.
+ // self._grad_sync = None
+
+ // # Information needed by backprop.
+ private Hashtable _history_map = new Hashtable();
+ public Hashtable history_map => _history_map;
+ private Hashtable _switch_map = new Hashtable();
+ public Hashtable switch_map => _switch_map;
+ // self._unused_exits = []
+ // self._deferred_exits = []
+ // self._forward_loop_exits = list(forward_ctxt.loop_exits)
+ // self._pending_exits_count = len(forward_ctxt.loop_exits)
+
+ // self._outer_grad_state = outer_grad_state
+ // if outer_grad_state:
+ // outer_forward_ctxt = outer_grad_state.forward_context
+ // else:
+ // if not hasattr(forward_ctxt, "outer_context"):
+ // raise ValueError("Failed to call gradients on a while loop without"
+ // "properly serializing graph via MetaGraphDef")
+ // outer_forward_ctxt = forward_ctxt.outer_context
+
+ // # Add the forward loop counter.
+ // with forward_ctxt._graph.as_default(): # pylint: disable=protected-access
+ // if outer_forward_ctxt:
+ // outer_forward_ctxt.Enter()
+ // cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
+ // if outer_forward_ctxt:
+ // outer_forward_ctxt.Exit()
+ // self._forward_context = forward_ctxt
+ // self._forward_index = forward_index
+
+ // # Add the backprop WhileContext, and the backprop loop counter.
+ // if outer_grad_state:
+ // # This is a nested loop. Remember the iteration counts for each
+ // # execution of this inner loop.
+ // outer_forward_ctxt.AddName(cnt.name)
+ // history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
+
+ // outer_grad_ctxt = outer_grad_state.grad_context
+ // outer_grad_ctxt.Enter()
+ // self._grad_context = WhileContext(
+ // maximum_iterations=forward_ctxt.maximum_iterations,
+ // parallel_iterations=forward_ctxt.parallel_iterations,
+ // back_prop=forward_ctxt.back_prop,
+ // swap_memory=forward_ctxt.swap_memory,
+ // name=forward_ctxt.name,
+ // grad_state=self)
+ // real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
+ // self._grad_index = self._grad_context.AddBackpropLoopCounter(
+ // real_cnt, outer_grad_state)
+ // outer_grad_ctxt.Exit()
+ // else:
+ // if outer_forward_ctxt:
+ // outer_forward_ctxt.Enter()
+ // self._grad_context = WhileContext(
+ // maximum_iterations=forward_ctxt.maximum_iterations,
+ // parallel_iterations=forward_ctxt.parallel_iterations,
+ // back_prop=forward_ctxt.back_prop,
+ // swap_memory=forward_ctxt.swap_memory,
+ // name=forward_ctxt.name,
+ // grad_state=self)
+ // self._grad_index = self._grad_context.AddBackpropLoopCounter(
+ // cnt, outer_grad_state)
+ // if outer_forward_ctxt:
+ // outer_forward_ctxt.Exit()
+
+ // @property
+ // def outer_grad_state(self):
+ // """The grad loop state for outer loop."""
+ // return self._outer_grad_state
+
+ // @property
+ // def forward_context(self):
+ // """The while loop context for forward."""
+ // return self._forward_context
+
+ // @property
+ // def forward_index(self):
+ // """The loop index of forward loop."""
+ // return self._forward_index
+
+ // @property
+ // def forward_sync(self):
+ // """A control trigger node for synchronization in the forward loop.
+
+ // One main use is to keep the push ops of a stack executed in the
+ // iteration order.
+ // """
+ // if self._forward_sync is None:
+ // with ops.control_dependencies(None):
+ // self._forward_sync = control_trigger(name="f_sync")
+ // self._forward_sync._set_control_flow_context(self._forward_context)
+ // self._forward_index.op._add_control_input(self._forward_sync)
+ // return self._forward_sync
+
+ // @property
+ // def grad_context(self):
+ // """The corresponding WhileContext for gradient."""
+ // return self._grad_context
+
+ // @property
+ // def grad_index(self):
+ // """The loop index of backprop loop."""
+ // return self._grad_index
+
+ // @property
+ // def grad_sync(self):
+ // """A control trigger node for synchronization in the grad loop.
+
+ // One main use is to keep the pop ops of a stack executed in the
+ // iteration order.
+ // """
+ // if self._grad_sync is None:
+ // with ops.control_dependencies(None):
+ // self._grad_sync = control_trigger(name="b_sync")
+ // self._grad_sync._set_control_flow_context(self._grad_context)
+ // self._grad_index.op._add_control_input(self._grad_sync)
+ // if self._grad_context.outer_context:
+ // self._grad_context.outer_context.AddInnerOp(self._grad_sync)
+ // return self._grad_sync
+
+ // @property
+ // def history_map(self):
+ // """The map that records all the tensors needed for backprop."""
+ // return self._history_map
+
+ // @property
+ // def switch_map(self):
+ // """The map that records all the Switch ops for the while loop."""
+ // return self._switch_map
+
+ // @property
+ // def unused_exits(self):
+ // """The list of "unused" exits."""
+ // return self._unused_exits
+
+ // @property
+ // def deferred_exits(self):
+ // """The list of "deferred" exits."""
+ // return self._deferred_exits
+
+ // @property
+ // def forward_loop_exits(self):
+ // """The list of exits of the forward loop."""
+ // return self._forward_loop_exits
+
+ // @property
+ // def pending_exits_count(self):
+ // """The number of exits we expect to see but haven't."""
+ // return self._pending_exits_count
+
+ // @pending_exits_count.setter
+ // def pending_exits_count(self, cnt):
+ // """Set the pending count to cnt."""
+ // self._pending_exits_count = cnt
+
+ /// <summary>
+ /// Add an accumulator for each forward tensor that is needed in backprop.
+ ///
+ /// This is added to the forward loop at the first time when a tensor
+ /// in the forward loop is used by backprop gradient computation loop.
+ /// We create an accumulator that accumulates the value of tensor at each
+ /// iteration. Called in the control flow context where gradients() is called.
+ ///
+ /// The pseudocode is:
+ /// ```
+ /// acc = stack();
+ /// while (_pivot) {
+ /// acc = stack_push(acc, value);
+ /// }
+ /// ```
+ ///
+ /// We make sure that the stack push op in one iteration is executed before
+ /// next iteration. This is achieved by adding a control edge from
+ /// `forward_index.op.inputs[0].op` to the push op, and another control
+ /// edge from the push op to either `forward_index.op` or `forward_sync`.
+ /// </summary>
+ /// <param name="value">The source tensor in forward that is to be accumulated.</param>
+ /// <param name="dead_branch">True iff the tensor is on a dead branch of a cond.</param>
+ /// <returns>The stack that contains the accumulated history of the tensor.</returns>
+ public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
+ {
+ throw new NotImplementedException("AddForwardAccumulator");
+ // # curr_ctxt is the context that tf.gradients was called in.
+ // with self._forward_index.graph.as_default():
+ // curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
+ // with ops.control_dependencies(None):
+ // if curr_ctxt:
+ // curr_ctxt.Enter()
+ // with ops.colocate_with(value):
+ // # We only need to pass maximum_iterations to the stack if
+ // # we're inside an XLA context.
+ // if not util.IsInXLAContext(value.op):
+ // max_size = constant_op.constant(-1, dtypes.int32)
+ // else:
+ // max_size = GetMaxSizeFromNestedMaximumIterations(
+ // value, self.forward_context)
+ // acc = gen_data_flow_ops.stack_v2(
+ // max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
+ // if curr_ctxt:
+ // curr_ctxt.Exit()
+
+ // # Make acc available in the forward context.
+ // enter_acc = self.forward_context.AddValue(acc)
+
+ // # Add the stack_push op in the context of value.op.
+ // swap_enabled = self.forward_context.swap_memory
+ // value_ctxt = util.GetOutputContext(value.op)
+ // if value_ctxt == self.forward_context:
+ // # value is not nested in the forward context.
+ // self.forward_context.Enter()
+ // push = gen_data_flow_ops.stack_push_v2(
+ // enter_acc, value, swap_memory=swap_enabled)
+ // self.forward_context.Exit()
+ // # Protect stack push and order it before forward_index.
+ // self.forward_index.op._add_control_input(push.op)
+ // else:
+ // # value is in a cond context within the forward context.
+ // if not isinstance(value_ctxt, CondContext):
+ // raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
+ // if dead_branch:
+ // # The special case for creating a zero tensor for a dead
+ // # branch of a switch. See ControlFlowState.ZerosLike().
+ // value_ctxt.outer_context.Enter()
+ // push = gen_data_flow_ops.stack_push_v2(
+ // enter_acc, value, swap_memory=swap_enabled)
+ // value_ctxt.outer_context.Exit()
+ // push.op._set_control_flow_context(value_ctxt)
+ // else:
+ // value_ctxt.Enter()
+ // push = gen_data_flow_ops.stack_push_v2(
+ // enter_acc, value, swap_memory=swap_enabled)
+ // value_ctxt.Exit()
+ // # Protect stack push and order it before forward_sync.
+ // self.forward_sync._add_control_input(push.op)
+ // # Order stack push after the successor of forward_index
+ // add_op = self.forward_index.op.inputs[0].op
+ // push.op._add_control_input(add_op)
+ // return acc
+ }
+
+ // """Add the getter for an accumulated value in the grad context.
+ //
+ // This is added to the backprop loop. Called in the grad context to
+ // get the value of an accumulated value. The stack pop op must be guarded
+ // by the pred of the controlling cond.
+ //
+ // Args:
+ // history_value: The history (a stack) of a value.
+ // value: The value that is pushed onto the stack.
+ // dead_branch: True iff the tensor is on a dead branch of a cond.
+ //
+ // Returns:
+ // The current value (the top of the stack).
+ // """
+ public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false)
+ {
+ throw new NotImplementedException();
+ // history_ctxt = history_value.op._get_control_flow_context()
+ // # Find the cond context that controls history_value if any.
+ // cond_ctxt = None
+ // value_ctxt = value.op._get_control_flow_context()
+ // while value_ctxt and value_ctxt != history_ctxt:
+ // if isinstance(value_ctxt, CondContext):
+ // cond_ctxt = value_ctxt
+ // break
+ // value_ctxt = value_ctxt.outer_context
+ // with ops.control_dependencies(None):
+ // self.grad_context.Enter()
+ // if cond_ctxt:
+ // # Guard stack pop with a switch if it is controlled by a cond.
+ // grad_state = self
+ // pred = None
+ // while pred is None and grad_state:
+ // pred = grad_state.history_map.get(cond_ctxt.pred.name)
+ // grad_state = grad_state.outer_grad_state
+ // if pred is None:
+ // pred = cond_ctxt.pred
+ // branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
+ // history_value = _SwitchRefOrTensor(history_value, pred)[branch]
+ // pop = gen_data_flow_ops.stack_pop_v2(history_value,
+ // value.dtype.base_dtype)
+ // pop.set_shape(value.get_shape())
+ // self.grad_context.Exit()
+ // parallel_iterations = self.grad_context.parallel_iterations
+ // if parallel_iterations > 1:
+ // # All pops are ordered after pivot_for_body and before grad_sync.
+ // self.grad_sync._add_control_input(pop.op)
+ // return pop
+ }
+
+ // def GetRealValue(self, value):
+ // """Get the real value of `value`.
+
+ // If backprop "uses" a value produced by forward inference, an accumulator
+ // is added in the forward loop to accumulate its values. We use the
+ // accumulated value. This method must be called in the grad loop context.
+ // `value` must be in forward and needed for backprop.
+
+ // Args:
+ // value: A tensor to be captured.
+
+ // Returns:
+ // The same tensor obtained from the saved history.
+ // """
+ // assert value.op.type not in ["Variable", "VariableV2"]
+ // real_value = self._history_map.get(value.name)
+ // if real_value is None:
+ // cur_value = value
+ // cur_grad_state = self
+ // while True:
+ // enter_op = util.GetLoopConstantEnter(cur_value)
+ // if enter_op:
+ // # Special case: cur_value comes from a constant Enter node.
+ // cur_value = enter_op.inputs[0]
+ // cur_grad_state = cur_grad_state.outer_grad_state
+ // if cur_grad_state is None:
+ // # We are now outside all nested loops for this gradient(),
+ // # so `value` is a loop invariant and there is no need to
+ // # save the history of value. Just make cur_value to enter
+ // # the right control flow context.
+ // real_value = self._grad_context.AddValue(cur_value)
+ // break
+ // elif constant_op.is_constant(cur_value):
+ // # If the value to be forwarded is a constant, clone the constant in
+ // # the gradient loop rather than using a stack.
+ // # TODO(phawkins): consider hoisting the constant out of the loop
+ // # instead.
+ // real_value = constant_op.constant(
+ // tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
+ // break
+ // else:
+ // # Record the history of this value in forward_ctxt.
+ // self._grad_context.Exit()
+ // history_value = cur_grad_state.AddForwardAccumulator(cur_value)
+ // self._grad_context.Enter()
+ // break
+
+ // if real_value is None:
+ // # Add the stack pop op in the grad context.
+ // real_value = cur_grad_state.AddBackpropAccumulatedValue(
+ // history_value, cur_value)
+ // if cur_grad_state != self:
+ // real_value = self._grad_context.AddValue(real_value)
+ // self._history_map[value.name] = real_value
+ // return real_value
+
+
+ }
+}
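The `<summary>` of `AddForwardAccumulator` describes the stack-based accumulation, and the commented Python covers all the cases. The sketch below handles only the simplest path (the value is produced directly in the forward while context, no dead cond branch). It is a hedged illustration, not part of the patch: the `gen_data_flow_ops.stack_v2` / `stack_push_v2` wrappers, the `forward_context` / `forward_index` members and `WhileContext.swap_memory` are assumed to exist with these shapes.

```cs
// Hedged sketch; every member marked "assumed" below is not provided by this patch.
public Tensor AddForwardAccumulatorSketch(Tensor value)
{
    // Create the stack in the context where tf.gradients() was called.
    var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
    curr_ctxt?.Enter();
    var acc = gen_data_flow_ops.stack_v2(max_size: -1,
        elem_type: value.dtype, name: "f_acc");                       // assumed wrapper
    curr_ctxt?.Exit();

    // Make the stack handle visible inside the forward loop, then push the
    // current value once per iteration.
    var enter_acc = forward_context.AddValue(acc);                    // assumed property
    forward_context.Enter();
    var push = gen_data_flow_ops.stack_push_v2(enter_acc, value,
        swap_memory: forward_context.swap_memory);                    // assumed wrapper
    forward_context.Exit();

    // Order the push before the next iteration's counter update.
    forward_index.op._add_control_input(push.op);                     // assumed property
    return acc;
}
```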
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
index 7fdd22f5..f9dde8c4 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
@@ -4,13 +4,15 @@ using System.Text;
namespace Tensorflow
{
- public interface IControlFlowContext
- {
- void AddOp(Operation op);
- IControlFlowContext outer_context { get; }
- HashSet values { get; }
- Tensor AddValue(Tensor val);
- void AddInnerOp(Operation resultOp);
- object to_proto();
- }
+ // henon: this was too much trouble. There is no value, only cost, in using an interface here.
+ //public interface IControlFlowContext
+ //{
+ // void AddOp(Operation op);
+ // IControlFlowContext outer_context { get; }
+ // HashSet values { get; }
+ // Tensor pivot { get; }
+ // Tensor AddValue(Tensor val);
+ // void AddInnerOp(Operation resultOp);
+ // object to_proto();
+ //}
}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
index d800679b..966ac83f 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
@@ -1,11 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Operations.ControlFlows;
namespace Tensorflow.Operations
{
public class WhileContext : ControlFlowContext
{
+ private bool _back_prop = true;
+
+ private GradLoopState _grad_state = null;
+
+ public override WhileContext GetWhileContext()
+ {
+ return this;
+ }
+
+
+ public override GradLoopState grad_state => _grad_state;
+
+ public override bool back_prop => _back_prop;
+
public static WhileContext from_proto(object proto)
{
throw new NotImplementedException();
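The new `_back_prop` and `_grad_state` fields have no writer yet in this hunk. The commented Python in GradLoopState.cs constructs the backprop loop with `WhileContext(..., back_prop=..., grad_state=self)`, so a constructor along the lines below would eventually be needed; the parameter names follow those Python keyword arguments and are assumptions, not something this patch defines.

```cs
// Hedged constructor sketch; parameters mirror the Python keyword arguments
// used in the GradLoopState comments and are not confirmed by this patch.
public WhileContext(int parallel_iterations = 10,
                    bool back_prop = true,
                    bool swap_memory = false,
                    string name = "while_context",
                    GradLoopState grad_state = null)
{
    _back_prop = back_prop;
    _grad_state = grad_state;
    // parallel_iterations, swap_memory and name handling omitted in this sketch.
}
```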
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
index 262d8e75..9b3aefe2 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -7,7 +7,7 @@ namespace Tensorflow
{
public partial class Operation
{
- private IControlFlowContext _control_flow_context;
+ private ControlFlowContext _control_flow_context;
///
/// Add this op to its control flow context.
@@ -39,12 +39,12 @@ namespace Tensorflow
_add_control_input(op);
}
- public void _set_control_flow_context(IControlFlowContext ctx)
+ public void _set_control_flow_context(ControlFlowContext ctx)
{
_control_flow_context = ctx;
}
- public IControlFlowContext _get_control_flow_context()
+ public ControlFlowContext _get_control_flow_context()
{
return _control_flow_context;
}
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index 11950b46..08b8c8b5 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations;
+using Tensorflow.Operations.ControlFlows;
using util = Tensorflow.control_flow_util;
namespace Tensorflow
@@ -93,9 +94,9 @@ namespace Tensorflow
///
///
///
- public static object MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops)
+ public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops)
{
- object loop_state = null;
+ ControlFlowState loop_state = null;
foreach (var op in between_op_list)
{
@@ -103,7 +104,7 @@ namespace Tensorflow
{
if(loop_state == null)
{
- // loop_state = ControlFlowState();
+ loop_state = new ControlFlowState();
}
}
}
@@ -207,7 +208,7 @@ namespace Tensorflow
/// `(output_false, output_true)`: If `pred` is true, data will be forwarded to
/// `output_true`, otherwise it goes to `output_false`.
///
- public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch")
+ public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch")
{
data = ops.convert_to_tensor_or_indexed_slices(data, name: "data");
// NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
@@ -298,7 +299,9 @@ namespace Tensorflow
*/
// Add the Switch to the graph.
- var (p_2, p_1) = @switch(pred, pred);
+ var switch_result = @switch(pred, pred);
+ var p_2 = switch_result[0];
+ var p_1 = switch_result[1];
var pivot_1 = array_ops.identity(p_1, name: "switch_t");
var pivot_2 = array_ops.identity(p_2, name: "switch_f");
pred = array_ops.identity(pred, name: "pred_id");
@@ -379,7 +382,9 @@ namespace Tensorflow
return with(ops.name_scope(name, "cond", new { pred }), delegate
{
// Add the Switch to the graph.
- var (p_2, p_1) = @switch(pred, pred);
+ var switch_result = @switch(pred, pred);
+ var p_2 = switch_result[0];
+ var p_1 = switch_result[1];
var pivot_1 = array_ops.identity(p_1, name: "switch_t");
var pivot_2 = array_ops.identity(p_2, name: "switch_f");
pred = array_ops.identity(pred, name: "pred_id");
@@ -460,7 +465,7 @@ namespace Tensorflow
///
///
///
- public static (Tensor, Tensor) @switch(Tensor data,
+ public static Tensor[] @switch(Tensor data,
Tensor pred,
TF_DataType dtype = TF_DataType.DtInvalid,
string name = null)
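With `@switch` and `_SwitchRefOrTensor` now returning `Tensor[]` instead of a `(Tensor, Tensor)` tuple, call sites index the result (0 = `output_false`, 1 = `output_true`), which also lets a cond pick its branch directly, as the CondContext hunks above already do. A small illustration; `data`, `pred` and `branch` are placeholders:

```cs
// Index 0 carries output_false, index 1 carries output_true.
var results = control_flow_ops._SwitchRefOrTensor(data, pred);
var taken = results[branch];          // the branch this context executes
var not_taken = results[1 - branch];  // the dead branch
```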
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
index 98ccbb06..5e2fc43e 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
@@ -30,7 +30,7 @@ namespace Tensorflow
/// <summary>
/// Return the control flow context for the output of an op.
/// </summary>
- public static IControlFlowContext GetOutputContext(Operation op)
+ public static ControlFlowContext GetOutputContext(Operation op)
{
var ctxt = op._get_control_flow_context();
// Exit nodes usually have a control flow context, except in the case where the
diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
index 31e2cad3..78e70053 100644
--- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
@@ -33,14 +33,14 @@ namespace Tensorflow
/// output_false: A `Tensor`. Has the same type as `data`.
/// output_true: A `Tensor`. Has the same type as `data`.
///
- public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null)
+ public static Tensor[] @switch(Tensor data, Tensor pred, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred });
var _inputs_flat = _op.inputs;
var _attrs = ("T", _op.get_attr("T"));
// TODO: missing original code
//_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name);
- return (_op.outputs[0], _op.outputs[1]);
+ return new[] { _op.outputs[0], _op.outputs[1] };
}
public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)