diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
new file mode 100644
index 00000000..af92e905
--- /dev/null
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
@@ -0,0 +1,105 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Eager;
+
+namespace Tensorflow
+{
+ public partial class Graph
+ {
+ public Context _control_flow_context;
+
+ private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>();
+ public Queue<_ControlDependenciesController> _control_dependencies_stack
+ {
+ get
+ {
+ return _graph_control_dependencies_stack;
+ }
+ set
+ {
+ _graph_control_dependencies_stack = value;
+ }
+ }
+
+ ///
+ /// For an op that takes `input_ops` as inputs, compute control inputs.
+ ///
+ /// The data input ops for an op to be created.
+ /// A list of control inputs for the op to be created.
+ private Operation[] _control_dependencies_for_inputs(Operation[] input_ops)
+ {
+ Operation[] ret = new Operation[0];
+
+ foreach(var controller in _control_dependencies_stack)
+ {
+ bool dominated = false;
+ // If any of the input_ops already depends on the inputs from controller,
+ // we say that the new op is dominated (by that input), and we therefore
+ // do not need to add control dependencies for this controller's inputs.
+ foreach(var op in input_ops)
+ {
+ if (controller.op_in_group(op))
+ {
+ dominated = true;
+ break;
+ }
+ }
+
+ if (!dominated)
+ ret = controller.control_inputs.Where(x => !input_ops.Contains(x)).ToArray();
+ }
+
+ return ret;
+ }
+
+ public _ControlDependenciesController control_dependencies(Operation[] control_inputs)
+ {
+ if (control_inputs == null)
+ return new _ControlDependenciesController(this, null);
+
+ var control_ops = new List();
+ foreach (var c in control_inputs)
+ {
+ control_ops.Add(c);
+ }
+
+ return new _ControlDependenciesController(this, control_ops);
+ }
+
+ ///
+ /// Returns the current control flow context.
+ ///
+ /// A context object.
+ public Context _get_control_flow_context()
+ {
+ return _control_flow_context;
+ }
+
+ ///
+ /// Sets the current control flow context.
+ ///
+ /// a context object.
+ public void _set_control_flow_context(Context ctx)
+ {
+ _control_flow_context = ctx;
+ }
+
+ public void _push_control_dependencies_controller(_ControlDependenciesController controller)
+ {
+ _control_dependencies_stack.Enqueue(controller);
+ }
+
+ public void _pop_control_dependencies_controller(_ControlDependenciesController controller)
+ {
+ _control_dependencies_stack.Dequeue();
+ }
+
+ public void _record_op_seen_by_control_dependencies(Operation op)
+ {
+ foreach (var controller in _control_dependencies_stack)
+ controller.add_op(op);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 963a1a40..3d46196d 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -142,19 +142,9 @@ namespace Tensorflow
return op;
}
- ///
- /// For an op that takes `input_ops` as inputs, compute control inputs.
- ///
- /// The data input ops for an op to be created.
- /// A list of control inputs for the op to be created.
- private Operation[] _control_dependencies_for_inputs(Operation[] input_ops)
- {
- return new Operation[0];
- }
-
private void _create_op_helper(Operation op, bool compute_device = true)
{
-
+ _record_op_seen_by_control_dependencies(op);
}
public void _add_op(Operation op)
diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
new file mode 100644
index 00000000..08302cc1
--- /dev/null
+++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
@@ -0,0 +1,80 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Eager;
+
+namespace Tensorflow
+{
+ ///
+ /// Context manager for `control_dependencies()`
+ ///
+ public class _ControlDependenciesController : IPython
+ {
+ private Graph _graph;
+ private List _control_inputs_val;
+ private List _seen_nodes;
+ private Queue<_ControlDependenciesController> _old_stack;
+ private bool _new_stack;
+ private Context _old_control_flow_context;
+
+ public Operation[] control_inputs => _control_inputs_val.ToArray();
+
+ public _ControlDependenciesController(Graph graph, List control_inputs)
+ {
+ _graph = graph;
+ if (control_inputs == null)
+ {
+ _control_inputs_val = new List();
+ _new_stack = true;
+ }
+ else
+ {
+ _control_inputs_val = control_inputs;
+ _new_stack = false;
+ }
+
+ _seen_nodes = new List();
+ }
+
+ public void add_op(Operation op)
+ {
+ _seen_nodes.Add(op);
+ }
+
+ public bool op_in_group(Operation op)
+ {
+ return _seen_nodes.Contains(op);
+ }
+
+ public void __enter__()
+ {
+ if (_new_stack)
+ {
+ // Clear the control_dependencies graph.
+ _old_stack = _graph._control_dependencies_stack;
+ _graph._control_dependencies_stack = new Queue<_ControlDependenciesController>();
+
+ // Clear the control_flow_context too.
+ _old_control_flow_context = _graph._get_control_flow_context();
+ _graph._set_control_flow_context(null);
+ }
+
+ _graph._push_control_dependencies_controller(this);
+ }
+
+ public void __exit__()
+ {
+ _graph._pop_control_dependencies_controller(this);
+ if (_new_stack)
+ {
+ _graph._control_dependencies_stack = _old_stack;
+ _graph._set_control_flow_context(_old_control_flow_context);
+ }
+ }
+
+ public void Dispose()
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
index 55987262..5db34ce9 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
@@ -39,6 +39,14 @@ namespace Tensorflow
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);
+ public Operation[] control_inputs
+ {
+ get
+ {
+ return GetControlInputs();
+ }
+ }
+
public unsafe Operation[] GetControlInputs()
{
var control_inputs = new Operation[NumControlInputs];
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index ef5d9813..dfc89e67 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -49,15 +49,53 @@ namespace Tensorflow
c_api.TF_FinishOperation(desc, status);
}
- public Operation(NodeDef node_def, Graph g, List inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
+ ///
+ /// Creates an `Operation`.
+ ///
+ /// `node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.
+ /// `Graph`. The parent graph.
+ /// list of `Tensor` objects. The inputs to this `Operation`.
+ /// list of `DType` objects.
+ ///
+ /// list of operations or tensors from which to have a
+ /// control dependency.
+ ///
+ ///
+ /// List of `DType` objects representing the
+ /// types of the tensors accepted by the `Operation`. By default
+ /// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect
+ /// reference-typed inputs must specify these explicitly.
+ ///
+ ///
+ ///
+ public Operation(NodeDef node_def, Graph g, List inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
Graph = g;
+ // Build the list of control inputs.
+ var control_input_ops = new List();
+ if(control_inputs != null)
+ {
+ foreach(var c in control_inputs)
+ {
+ switch (c)
+ {
+ case Operation c1:
+ control_input_ops.Add(c1);
+ break;
+ default:
+ throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
+ }
+ }
+ }
+
+ // This will be set by self.inputs.
+
_id_value = Graph._next_id();
if(op_def == null)
op_def = g.GetOpDef(node_def.Op);
- _handle = ops._create_c_op(g, node_def, inputs);
+ _handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray());
output_types = new TF_DataType[NumOutputs];
diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
index 46317bac..0283e013 100644
--- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs
+++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
@@ -34,6 +34,14 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_AddInput(IntPtr desc, TF_Output input);
+ ///
+ /// Call once per control input to `desc`.
+ ///
+ /// TF_OperationDescription*
+ /// TF_Operation*
+ [DllImport(TensorFlowLibName)]
+ public static extern void TF_AddControlInput(IntPtr desc, IntPtr input);
+
///
/// For inputs that take a list of tensors.
/// inputs must point to TF_Output[num_inputs].
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index 37b91448..6146e54d 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -13,9 +13,8 @@ namespace Tensorflow
{
name = namescope;
- var ops_on_device = new Dictionary();
-
// Sorts *inputs according to their devices.
+ var ops_on_device = new Dictionary();
foreach (var inp in inputs)
{
ops_on_device[inp.Device] = new Operation[] { inp };
@@ -24,7 +23,9 @@ namespace Tensorflow
// 1-level tree. The root node is the returned NoOp node.
if (ops_on_device.Count == 1)
{
- return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name);
+ var dev = ops_on_device.Keys.First();
+ var deps = ops_on_device.Values.First();
+ return _GroupControlDeps(dev, deps, name);
}
// 2-level tree. The root node is the returned NoOp node.
@@ -35,12 +36,21 @@ namespace Tensorflow
private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "")
{
- if (string.IsNullOrEmpty(dev))
+ Operation result = null;
+
+ Python.with(ops.control_dependencies(deps), delegate
{
- return gen_control_flow_ops.no_op(name);
- }
+ if (string.IsNullOrEmpty(dev))
+ {
+ result = gen_control_flow_ops.no_op(name);
+ }
+ else
+ {
+ result = gen_control_flow_ops.no_op(name);
+ }
+ });
- return null;
+ return result;
}
}
}
diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs
index 2e4d6423..9059be1e 100644
--- a/src/TensorFlowNET.Core/Python.cs
+++ b/src/TensorFlowNET.Core/Python.cs
@@ -13,5 +13,30 @@ namespace Tensorflow
{
Console.WriteLine(obj.ToString());
}
+
+ public static void with(IPython py, Action action)
+ {
+ try
+ {
+ py.__enter__();
+ action();
+ }
+ catch (Exception ex)
+ {
+ throw ex;
+ }
+ finally
+ {
+ py.__exit__();
+ py.Dispose();
+ }
+ }
+ }
+
+ public interface IPython : IDisposable
+ {
+ void __enter__();
+
+ void __exit__();
}
}
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs
index 6e4d28f5..5e7fe2a5 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs
@@ -20,5 +20,10 @@ namespace Tensorflow
{
return var._AsTensor();
}
+
+ public static implicit operator RefVariable(Tensor var)
+ {
+ return null;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs
index fc15d97d..5736dd23 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs
@@ -166,5 +166,10 @@ namespace Tensorflow
// Recursively build initializer expressions for inputs.
return op;
}
+
+ public override string ToString()
+ {
+ return $"tf.Variable '{name}' shape={shape} dtype={dtype}";
+ }
}
}
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index b8c8dc1c..3440298c 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -78,7 +78,29 @@ namespace Tensorflow
}
}
- public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs)
+ ///
+ /// Wrapper for `Graph.control_dependencies()` using the default graph.
+ ///
+ ///
+ public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
+ {
+ return get_default_graph().control_dependencies(control_inputs);
+ }
+
+ ///
+ /// Creates a TF_Operation.
+ ///
+ /// a `Graph`.
+ /// `node_def_pb2.NodeDef` for the operation to create.
+ ///
+ /// A list of `Tensor`s (corresponding to scalar inputs) and lists of
+ /// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
+ /// "list(int64)"). The length of the list should be equal to the number of
+ /// inputs specified by this operation's op def.
+ ///
+ /// A list of `Operation`s to set as control dependencies.
+ /// A wrapped TF_Operation*.
+ public static IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs, Operation[] control_inputs)
{
var op_desc = graph.NewOperation(node_def.Op, node_def.Name);
@@ -102,6 +124,8 @@ namespace Tensorflow
var status = new Status();
// Add control inputs
+ foreach (var control_input in control_inputs)
+ c_api.TF_AddControlInput(op_desc, control_input);
// Add attrs
foreach (var attr in node_def.Attr)
@@ -170,8 +194,11 @@ namespace Tensorflow
// inner_device_stack = default_graph._device_function_stack
// var outer_context = default_graph.as_default;
- var outer_graph = get_default_graph();
- // outer_device_stack = None
+ Python.with(ops.control_dependencies(null), delegate
+ {
+ var outer_graph = get_default_graph();
+ // outer_device_stack = None
+ });
}
private static int uid_number = 0;
diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs
index 69e4f39e..95313d4c 100644
--- a/test/TensorFlowNET.UnitTest/VariableTest.cs
+++ b/test/TensorFlowNET.UnitTest/VariableTest.cs
@@ -46,14 +46,13 @@ namespace TensorFlowNET.UnitTest
var x = tf.Variable(10, name: "x");
var model = tf.global_variables_initializer();
-
using (var session = tf.Session())
{
- session.run(x.initializer);
+ session.run(model);
for(int i = 0; i < 5; i++)
{
- var x1 = x + 1;
- var result = session.run(x1);
+ x = x + 1;
+ var result = session.run(x);
print(result);
}
}